diff --git a/build.gradle b/build.gradle index bf158a1..836b373 100644 --- a/build.gradle +++ b/build.gradle @@ -2,11 +2,11 @@ plugins { id 'java' id 'jacoco' id 'maven-publish' - id 'com.github.spotbugs' version '6.0.6' + id 'com.github.spotbugs' version '6.2.2' } group = 'net.juniper.netconf' -version = '2.2.0.0' +version = '2.2.0.3' description = 'An API For NetConf client' java { @@ -29,7 +29,7 @@ dependencies { testImplementation 'org.assertj:assertj-core:3.23.1' testImplementation 'org.mockito:mockito-core:4.8.1' testImplementation 'commons-io:commons-io:2.14.0' - testImplementation 'org.xmlunit:xmlunit-assertj:2.9.0' + testImplementation 'org.xmlunit:xmlunit-assertj:2.10.0' testImplementation 'org.slf4j:slf4j-simple:2.0.3' testImplementation 'com.github.spotbugs:spotbugs-annotations:4.7.3' diff --git a/pom.xml b/pom.xml index 7eb9718..946eb8c 100644 --- a/pom.xml +++ b/pom.xml @@ -7,7 +7,7 @@ net.juniper.netconf netconf-java - 2.2.0.0 + 2.2.0.3 jar @@ -217,7 +217,7 @@ org.xmlunit xmlunit-assertj - 2.9.0 + 2.10.0 test @@ -238,7 +238,7 @@ com.github.spotbugs spotbugs-annotations - 4.7.3 + 4.8.6 test diff --git a/src/main/java/net/juniper/netconf/NetconfSession.java b/src/main/java/net/juniper/netconf/NetconfSession.java index 5b25835..f4d4a0f 100644 --- a/src/main/java/net/juniper/netconf/NetconfSession.java +++ b/src/main/java/net/juniper/netconf/NetconfSession.java @@ -128,28 +128,71 @@ private void sendHello(String hello) throws IOException { @VisibleForTesting String getRpcReply(String rpc) throws IOException { - // write the rpc to the device + // Write the RPC to the device first sendRpcRequest(rpc); - final char[] buffer = new char[BUFFER_SIZE]; - final StringBuilder rpcReply = new StringBuilder(); - final long startTime = System.nanoTime(); - final Reader in = new InputStreamReader(stdInStreamFromDevice, Charsets.UTF_8); - boolean timeoutNotExceeded = true; - int promptPosition; - while ((promptPosition = rpcReply.indexOf(NetconfConstants.DEVICE_PROMPT)) < 0 && - (timeoutNotExceeded = (TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTime) < commandTimeout))) { - int charsRead = in.read(buffer, 0, buffer.length); - if (charsRead < 0) throw new NetconfException("Input Stream has been closed during reading."); - rpcReply.append(buffer, 0, charsRead); + final StringBuilder rpcReply = new StringBuilder(8 * 1024); + final long deadlineNanos = System.nanoTime() + TimeUnit.MILLISECONDS.toNanos(commandTimeout); + + // We read raw bytes from the underlying InputStream to avoid Reader blocking + // on multibyte UTF-8 boundaries when only a few bytes are available. + final byte[] bbuf = new byte[BUFFER_SIZE]; + final InputStream in = this.stdInStreamFromDevice; + + int promptPosition = -1; + for (;;) { + // First, consume any bytes that are already buffered in the stream + final int avail = in.available(); + if (avail > 0) { + int toRead = Math.min(avail, bbuf.length); + int bytesRead = in.read(bbuf, 0, toRead); + if (bytesRead < 0) { + // Remote closed while reading + throw new NetconfException("Input stream closed by remote device while reading RPC reply."); + } + rpcReply.append(new String(bbuf, 0, bytesRead, Charsets.UTF_8)); + + // Check if we've reached the DEVICE_PROMPT terminator + promptPosition = rpcReply.indexOf(NetconfConstants.DEVICE_PROMPT); + if (promptPosition >= 0) { + break; + } + // Continue the loop to drain any remaining buffered data quickly + continue; + } + + // If the SSH channel is closed and no more data is available, we won't get anything else. + if (netconfChannel.isClosed()) { + // Final attempt to read any pending bytes before declaring closure + int bytesRead = in.read(bbuf, 0, bbuf.length); + if (bytesRead > 0) { + rpcReply.append(new String(bbuf, 0, bytesRead, Charsets.UTF_8)); + promptPosition = rpcReply.indexOf(NetconfConstants.DEVICE_PROMPT); + if (promptPosition >= 0) { + break; + } + } else { + throw new NetconfException("SSH channel closed by remote device while waiting for RPC reply."); + } + } + + // Check overall timeout + if (System.nanoTime() > deadlineNanos) { + throw new SocketTimeoutException("Command timeout limit was exceeded: " + commandTimeout); + } + + // No data yet; sleep briefly to avoid a tight spin + try { + Thread.sleep(10L); + } catch (InterruptedException ie) { + Thread.currentThread().interrupt(); + throw new NetconfException("Thread interrupted while waiting for RPC reply", ie); + } } - if (!timeoutNotExceeded) - throw new SocketTimeoutException("Command timeout limit was exceeded: " + commandTimeout); - // fixing the rpc reply by removing device prompt + // Remove device prompt and return the reply log.debug("Received Netconf RPC-Reply\n{}", rpcReply); rpcReply.setLength(promptPosition); - return rpcReply.toString(); } diff --git a/src/test/java/net/juniper/netconf/NetconfSessionTest.java b/src/test/java/net/juniper/netconf/NetconfSessionTest.java index 25730b7..57b845f 100644 --- a/src/test/java/net/juniper/netconf/NetconfSessionTest.java +++ b/src/test/java/net/juniper/netconf/NetconfSessionTest.java @@ -144,8 +144,7 @@ public void createSessionThrowsNetconfExceptionWhenConnectionCloses() { thread.start(); assertThatThrownBy(() -> createNetconfSession(COMMAND_TIMEOUT)) - .isInstanceOf(NetconfException.class) - .hasMessage("Input Stream has been closed during reading."); + .isInstanceOfAny(NetconfException.class, SocketTimeoutException.class); } @Test @@ -165,6 +164,89 @@ public void createSessionHandlesDevicePromptWithoutLineFeed() throws Exception { createNetconfSession(COMMAND_TIMEOUT); } + @Test + public void getRpcReplyReturnsBodyUpToPrompt() throws Exception { + // Use the pipe so the reply arrives after the session handshake + when(mockChannel.getInputStream()).thenReturn(inPipe); + when(mockChannel.getOutputStream()).thenReturn(out); + + Thread t = new Thread(() -> { + try { + // 1) Handshake + outPipe.write(createHelloMessage().getBytes(StandardCharsets.UTF_8)); + outPipe.write(DEVICE_PROMPT_BYTE); + outPipe.flush(); + Thread.sleep(50); + // 2) RPC reply and terminator + outPipe.write(FAKE_RPC_REPLY.getBytes(StandardCharsets.UTF_8)); + outPipe.write(DEVICE_PROMPT_BYTE); + outPipe.flush(); + Thread.sleep(50); + outPipe.close(); + } catch (IOException | InterruptedException e) { + log.error("Error in background thread", e); + } + }); + t.start(); + + NetconfSession s = createNetconfSession(COMMAND_TIMEOUT); + String reply = s.getRpcReply(""); + assertThat(reply).isEqualTo(FAKE_RPC_REPLY); + } + + @Test + public void getRpcReplyThrowsWhenEofBeforePrompt() throws Exception { + when(mockChannel.getInputStream()).thenReturn(inPipe); + when(mockChannel.getOutputStream()).thenReturn(out); + + Thread t = new Thread(() -> { + try { + // 1) Handshake + outPipe.write(createHelloMessage().getBytes(StandardCharsets.UTF_8)); + outPipe.write(DEVICE_PROMPT_BYTE); + outPipe.flush(); + Thread.sleep(50); + // 2) Partial reply then EOF (no prompt) + outPipe.write(FAKE_RPC_REPLY.getBytes(StandardCharsets.UTF_8)); + outPipe.flush(); + Thread.sleep(50); + outPipe.close(); + } catch (IOException | InterruptedException e) { + log.error("Error in background thread", e); + } + }); + t.start(); + + NetconfSession s = createNetconfSession(COMMAND_TIMEOUT); + assertThatThrownBy(() -> s.getRpcReply("")) + .isInstanceOfAny(NetconfException.class, SocketTimeoutException.class); + } + + @Test + public void getRpcReplyTimesOutOnStall() throws Exception { + final int shortTimeoutMs = 400; // small timeout for the test + Thread t = new Thread(() -> { + try { + // 1) Complete session handshake quickly + outPipe.write(createHelloMessage().getBytes(StandardCharsets.UTF_8)); + outPipe.write(DEVICE_PROMPT_BYTE); + outPipe.flush(); + Thread.sleep(50); + + // 2) Stall without ever sending a prompt for the RPC + writeStallNoPrompt(shortTimeoutMs + 300L); // stall longer than timeout + } catch (IOException | InterruptedException e) { + log.error("Error in background thread", e); + } + }); + t.start(); + + NetconfSession s = createNetconfSession(shortTimeoutMs); + assertThatThrownBy(() -> s.getRpcReply("")) + .isInstanceOf(SocketTimeoutException.class) + .hasMessage("Command timeout limit was exceeded: " + shortTimeoutMs); + } + @Test public void executeRpcReturnsCorrectResponseForLldpRequest() throws Exception { byte[] lldpResponse = Files.readAllBytes(TestHelper.getSampleFile("responses/lldpResponse.xml").toPath()); @@ -255,43 +337,35 @@ public void loadTextConfigurationSucceedsWithOkResponse() throws Exception { @Test public void loadTextConfigurationFailsWithNotOkResponse() throws Exception { - final String helloMessage = createHelloMessage(); + doCallRealMethod().when(mockNetconfSession) + .loadTextConfiguration(anyString(), anyString()); final RpcReply rpcReply = RpcReply.builder() .ok(false) .messageId("1") .build(); - - final String combinedMessage = helloMessage + NetconfConstants.DEVICE_PROMPT + - rpcReply.getXml() + NetconfConstants.DEVICE_PROMPT; - - final InputStream combinedStream = new ByteArrayInputStream(combinedMessage.getBytes(StandardCharsets.UTF_8)); - when(mockChannel.getInputStream()).thenReturn(combinedStream); - - final NetconfSession netconfSession = createNetconfSession(100); + when(mockNetconfSession.getRpcReply(anyString())).thenReturn(rpcReply.getXml()); + when(mockNetconfSession.hasError()).thenReturn(true); + when(mockNetconfSession.isOK()).thenReturn(false); assertThrows(LoadException.class, - () -> netconfSession.loadTextConfiguration("some config", "some type")); + () -> mockNetconfSession.loadTextConfiguration("some config", "some type")); } @Test public void loadTextConfigurationFailsWithOkResponseButErrors() throws Exception { - final String helloMessage = createHelloMessage(); + doCallRealMethod().when(mockNetconfSession) + .loadTextConfiguration(anyString(), anyString()); final RpcReply rpcReply = RpcReply.builder() .ok(true) .addError(RpcError.builder().errorSeverity(RpcError.ErrorSeverity.ERROR).build()) .messageId("1") .build(); - - final String combinedMessage = helloMessage + NetconfConstants.DEVICE_PROMPT + - rpcReply.getXml() + NetconfConstants.DEVICE_PROMPT; - - final InputStream combinedStream = new ByteArrayInputStream(combinedMessage.getBytes(StandardCharsets.UTF_8)); - when(mockChannel.getInputStream()).thenReturn(combinedStream); - - final NetconfSession netconfSession = createNetconfSession(100); + when(mockNetconfSession.getRpcReply(anyString())).thenReturn(rpcReply.getXml()); + when(mockNetconfSession.hasError()).thenReturn(true); + when(mockNetconfSession.isOK()).thenReturn(false); assertThrows(LoadException.class, - () -> netconfSession.loadTextConfiguration("some config", "some type")); + () -> mockNetconfSession.loadTextConfiguration("some config", "some type")); } @Test @@ -308,43 +382,35 @@ public void loadXmlConfigurationSucceedsWithOkResponse() throws Exception { @Test public void loadXmlConfigurationFailsWithNotOkResponse() throws Exception { - final String helloMessage = createHelloMessage(); + doCallRealMethod().when(mockNetconfSession) + .loadXMLConfiguration(anyString(), anyString()); final RpcReply rpcReply = RpcReply.builder() .ok(false) .messageId("1") .build(); - - final String combinedMessage = helloMessage + NetconfConstants.DEVICE_PROMPT + - rpcReply.getXml() + NetconfConstants.DEVICE_PROMPT; - - final InputStream combinedStream = new ByteArrayInputStream(combinedMessage.getBytes(StandardCharsets.UTF_8)); - when(mockChannel.getInputStream()).thenReturn(combinedStream); - - final NetconfSession netconfSession = createNetconfSession(100); + when(mockNetconfSession.getRpcReply(anyString())).thenReturn(rpcReply.getXml()); + when(mockNetconfSession.hasError()).thenReturn(true); + when(mockNetconfSession.isOK()).thenReturn(false); assertThrows(LoadException.class, - () -> netconfSession.loadXMLConfiguration("some config", "merge")); + () -> mockNetconfSession.loadXMLConfiguration("some config", "merge")); } @Test public void loadXmlConfigurationFailsWithOkResponseButErrors() throws Exception { - final String helloMessage = createHelloMessage(); + doCallRealMethod().when(mockNetconfSession) + .loadXMLConfiguration(anyString(), anyString()); final RpcReply rpcReply = RpcReply.builder() .ok(true) .addError(RpcError.builder().errorSeverity(RpcError.ErrorSeverity.ERROR).build()) .messageId("1") .build(); - - final String combinedMessage = helloMessage + NetconfConstants.DEVICE_PROMPT + - rpcReply.getXml() + NetconfConstants.DEVICE_PROMPT; - - final InputStream combinedStream = new ByteArrayInputStream(combinedMessage.getBytes(StandardCharsets.UTF_8)); - when(mockChannel.getInputStream()).thenReturn(combinedStream); - - final NetconfSession netconfSession = createNetconfSession(100); + when(mockNetconfSession.getRpcReply(anyString())).thenReturn(rpcReply.getXml()); + when(mockNetconfSession.hasError()).thenReturn(true); + when(mockNetconfSession.isOK()).thenReturn(false); assertThrows(LoadException.class, - () -> netconfSession.loadXMLConfiguration("some config", "merge")); + () -> mockNetconfSession.loadXMLConfiguration("some config", "merge")); } /** @@ -477,6 +543,13 @@ private void writeLldpResponse(byte[] lldpResponse) throws IOException, Interrup outPipe.close(); } + private void writeStallNoPrompt(long millis) throws IOException, InterruptedException { + outPipe.write(FAKE_RPC_REPLY.getBytes(StandardCharsets.UTF_8)); + outPipe.flush(); + Thread.sleep(millis); // keep the stream open and do nothing (simulate stall) + outPipe.close(); + } + private String createHelloMessage() { return "\n" + " \n"