diff --git a/modules/cpr/src/main/java/org/atmosphere/container/JSR356Endpoint.java b/modules/cpr/src/main/java/org/atmosphere/container/JSR356Endpoint.java index 537aa1e012..e201431bae 100644 --- a/modules/cpr/src/main/java/org/atmosphere/container/JSR356Endpoint.java +++ b/modules/cpr/src/main/java/org/atmosphere/container/JSR356Endpoint.java @@ -225,11 +225,13 @@ public void onOpen(Session session, final EndpointConfig endpointConfig) { cookies.addAll(CookieUtil.ServerCookieDecoder.STRICT.decode(cookieHeader)); } - Enumeration attributeNames = handshakeSession.getAttributeNames(); - Map attributes = new ConcurrentHashMap<>(); - while (attributeNames.hasMoreElements()) { - String attributeName = attributeNames.nextElement(); - attributes.put(attributeName, handshakeSession.getAttribute(attributeName)); + final Map attributes = new ConcurrentHashMap<>(); + if (handshakeSession != null) { + Enumeration attributeNames = handshakeSession.getAttributeNames(); + while (attributeNames.hasMoreElements()) { + String attributeName = attributeNames.nextElement(); + attributes.put(attributeName, handshakeSession.getAttribute(attributeName)); + } } request = new AtmosphereRequestImpl.Builder() diff --git a/modules/cpr/src/test/java/org/atmosphere/container/version/JSR356WebSocketTest.java b/modules/cpr/src/test/java/org/atmosphere/container/version/JSR356WebSocketTest.java index a6bfeda700..9772863553 100644 --- a/modules/cpr/src/test/java/org/atmosphere/container/version/JSR356WebSocketTest.java +++ b/modules/cpr/src/test/java/org/atmosphere/container/version/JSR356WebSocketTest.java @@ -23,29 +23,34 @@ import jakarta.websocket.SendResult; import jakarta.websocket.Session; import jakarta.websocket.server.HandshakeRequest; +import jakarta.websocket.server.ServerEndpointConfig; import org.atmosphere.container.JSR356Endpoint; -import org.atmosphere.cpr.ApplicationConfig; import org.atmosphere.cpr.AtmosphereConfig; import org.atmosphere.cpr.AtmosphereFramework; import org.atmosphere.cpr.AtmosphereRequest; import org.atmosphere.websocket.WebSocketProcessor; +import org.mockito.Mockito; import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; import org.testng.annotations.BeforeMethod; import org.testng.annotations.Test; +import java.io.IOException; import java.lang.reflect.Field; import java.net.URI; +import java.util.Collections; import java.util.Enumeration; import java.util.HashMap; import java.util.Map; import java.util.Vector; +import static org.atmosphere.cpr.ApplicationConfig.JSR356_MAPPING_PATH; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -157,10 +162,9 @@ public void testAttributePropagationFromHandshakeSessionToAtmosphereRequest() th AtmosphereFramework mockFramework = mock(AtmosphereFramework.class); AtmosphereConfig mockConfig = mock(AtmosphereConfig.class); - // Stub the getAtmosphereConfig and getInitParameter methods when(mockFramework.getAtmosphereConfig()).thenReturn(mockConfig); when(mockFramework.getServletContext()).thenReturn(mock(ServletContext.class)); - when(mockConfig.getInitParameter(ApplicationConfig.JSR356_MAPPING_PATH)).thenReturn("/"); + when(mockConfig.getInitParameter(JSR356_MAPPING_PATH)).thenReturn("/"); when(mockConfig.getServletContext()).thenReturn(mock(ServletContext.class)); WebSocketProcessor webSocketProcessor = mock(WebSocketProcessor.class); @@ -180,4 +184,42 @@ public void testAttributePropagationFromHandshakeSessionToAtmosphereRequest() th "Attribute value should match the value from the HttpSession"); } } + + @Test + public void testOnOpenWithNullHandshakeSession() throws IOException { + Session mockSession = Mockito.mock(Session.class); + HandshakeRequest mockHandshakeRequest = Mockito.mock(HandshakeRequest.class); + ServerEndpointConfig mockConfig = Mockito.mock(ServerEndpointConfig.class); + WebSocketProcessor mockProcessor = Mockito.mock(WebSocketProcessor.class); + RemoteEndpoint.Async mockAsyncRemote = Mockito.mock(RemoteEndpoint.Async.class); + + AtmosphereFramework mockFramework = mock(AtmosphereFramework.class); + AtmosphereConfig mockAtmosphereConfig = mock(AtmosphereConfig.class); + + when(mockProcessor.handshake(Mockito.any(AtmosphereRequest.class))).thenReturn(true); + when(mockFramework.getAtmosphereConfig()).thenReturn(mockAtmosphereConfig); + when(mockFramework.getServletContext()).thenReturn(mock(ServletContext.class)); + when(mockAtmosphereConfig.getInitParameter(JSR356_MAPPING_PATH)).thenReturn("/"); + when(mockAtmosphereConfig.getServletContext()).thenReturn(mock(ServletContext.class)); + when(mockFramework.getAtmosphereConfig()).thenReturn(mockAtmosphereConfig); + when(mockFramework.getServletContext()).thenReturn(mock(ServletContext.class)); + when(mockAtmosphereConfig.getInitParameter(JSR356_MAPPING_PATH)).thenReturn("/"); + when(mockAtmosphereConfig.getServletContext()).thenReturn(mock(ServletContext.class)); + + JSR356Endpoint endpoint = new JSR356Endpoint(mockFramework, mockProcessor); + + when(mockSession.getAsyncRemote()).thenReturn(mockAsyncRemote); + when(mockSession.isOpen()).thenReturn(true); + when(mockSession.getRequestURI()).thenReturn(URI.create("/")); + + when(mockHandshakeRequest.getHttpSession()).thenReturn(null); + when(mockHandshakeRequest.getHeaders()).thenReturn(Collections.emptyMap()); + + endpoint.handshakeRequest(mockHandshakeRequest); + + endpoint.onOpen(mockSession, mockConfig); + + verify(mockSession, never()).close(Mockito.any()); + + } } \ No newline at end of file