4

I would like to authenticate WebSocket connections with Spring Security. According to section 23.2, "WebSocket Authentication", of the official Spring Security documentation, WebSockets reuse the same authentication information that is found in the HTTP request when the WebSocket connection is made. So I set up Spring Security to authenticate a REST service. My expectation was that a user who passes the REST authentication is allowed to open a WebSocket connection, and otherwise the WebSocket connection cannot be established. The code is as follows:

REST service for login: WssAuthService.java

import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestMethod;
import org.springframework.web.bind.annotation.RestController;

/**
 * REST controller for the login endpoint, mapped under {@code /dm}.
 *
 * <p>The controller performs no credential checks of its own; it only returns a fixed
 * confirmation message to callers that reach it.
 */
@RestController
@RequestMapping(path = "/dm")
public class WssAuthService {

    /** Body returned by {@link #login()}. */
    private static final String LOGIN_SUCCESS_MESSAGE = "Login success to WssBroker...";

    /**
     * Handles {@code GET /dm/login}.
     *
     * @return the fixed confirmation message
     */
    @RequestMapping(method = RequestMethod.GET, path = "/login")
    public String login() {
        return LOGIN_SUCCESS_MESSAGE;
    }
}

Spring Security configuration: WebSecurityConfig.java

import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.security.authentication.AuthenticationManager;
import org.springframework.security.config.annotation.authentication.builders.AuthenticationManagerBuilder;
import org.springframework.security.config.annotation.web.builders.HttpSecurity;
import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter;
import org.springframework.security.config.http.SessionCreationPolicy;
 
/**
 * Spring Security configuration: two in-memory users, HTTP Basic authentication required
 * for every request, CSRF protection disabled and a stateless session policy.
 */
@Configuration
@EnableWebSecurity
public class WebSecurityConfig extends WebSecurityConfigurerAdapter {

    /** Realm name configured for HTTP Basic authentication. */
    public static final String REALM = "MY_TEST_REALM";

    /**
     * Registers the in-memory users {@code admin} (role ADMIN) and {@code test} (role USER).
     *
     * @param auth builder the users are registered on
     * @throws Exception if the in-memory authentication cannot be configured
     */
    @Autowired
    public void configureGlobalSecurity(AuthenticationManagerBuilder auth) throws Exception {
        auth.inMemoryAuthentication()
                .withUser("admin").password("admin").roles("ADMIN")
                .and()
                .withUser("test").password("test").roles("USER");
    }

    /**
     * Configures HTTP security, one concern per statement.
     *
     * @param http the security builder to configure
     * @throws Exception if any configurer fails
     */
    @Override
    protected void configure(HttpSecurity http) throws Exception {
        http.csrf().disable();

        // Every request must be authenticated.
        http.authorizeRequests().anyRequest().authenticated();

        http.httpBasic()
                .realmName(REALM)
                .authenticationEntryPoint(getBasicAuthEntryPoint());

        // We don't need sessions to be created.
        http.sessionManagement().sessionCreationPolicy(SessionCreationPolicy.STATELESS);
    }

    /**
     * @return the entry point used for HTTP Basic authentication
     */
    @Bean
    public CustomBasicAuthenticationEntryPoint getBasicAuthEntryPoint() {
        return new CustomBasicAuthenticationEntryPoint();
    }

    /**
     * Exposes the {@link AuthenticationManager} as a bean.
     *
     * @return the authentication manager built by the parent class
     * @throws Exception if the parent class cannot provide it
     */
    @Bean
    public AuthenticationManager authenticationManagerBean() throws Exception {
        // Although this looks like useless code,
        // it is required to prevent Spring Boot auto-configuration.
        return super.authenticationManagerBean();
    }

}

WebSocket server configuration: WssBrokerConfig.java

import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.SerializationFeature;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.http.converter.json.MappingJackson2HttpMessageConverter;
import org.springframework.messaging.simp.config.MessageBrokerRegistry;
import org.springframework.security.config.annotation.web.messaging.MessageSecurityMetadataSourceRegistry;
import org.springframework.security.config.annotation.web.socket.AbstractSecurityWebSocketMessageBrokerConfigurer;
import org.springframework.web.socket.config.annotation.EnableWebSocketMessageBroker;
import org.springframework.web.socket.config.annotation.StompEndpointRegistry;

/**
 * STOMP-over-WebSocket broker configuration with message-level security rules.
 */
@Configuration
@EnableWebSocketMessageBroker
public class WssBrokerConfig extends AbstractSecurityWebSocketMessageBrokerConfigurer {

    /**
     * Declares the inbound message rules. They are registered in the order written here.
     *
     * @param messages registry the rules are added to
     */
    @Override
    protected void configureInbound(MessageSecurityMetadataSourceRegistry messages) {
        // Messages without a destination require an authenticated user.
        messages.nullDestMatcher().authenticated();

        // Anyone may subscribe to the notification topic.
        messages.simpSubscribeDestMatchers("/topic/notification").permitAll();

        // Every other destination requires an authenticated user.
        messages.simpDestMatchers("/**").authenticated();

        // Alternative rules, left disabled:
        //   .simpSubscribeDestMatchers("/user/**", "/topic/friends/*").hasRole("USER")
        //   .simpTypeMatchers(MESSAGE, SUBSCRIBE).denyAll()

        // Anything not matched by the rules above is rejected.
        messages.anyMessage().denyAll();
    }

    /**
     * Enables the simple broker for {@code /topic} and sets {@code /ws} as the
     * application destination prefix.
     *
     * @param config broker registry to configure
     */
    @Override
    public void configureMessageBroker(MessageBrokerRegistry config) {
        config.setApplicationDestinationPrefixes("/ws");
        config.enableSimpleBroker("/topic");
    }

    /**
     * Registers the {@code /dm-ws} STOMP endpoint with SockJS enabled.
     *
     * @param registry endpoint registry to add the endpoint to
     */
    @Override
    public void registerStompEndpoints(StompEndpointRegistry registry) {
        // Note: setAllowedOrigins is important here: since we have both http & websocket
        // servers, cross-origin accesses should be enabled.
        registry.addEndpoint("/dm-ws")
                .setAllowedOrigins("*")
                .withSockJS();
    }

    /**
     * Builds a Jackson HTTP message converter whose mapper does not fail on empty beans.
     *
     * @return the configured converter
     */
    @Bean
    public MappingJackson2HttpMessageConverter mappingJackson2HttpMessageConverter() {
        ObjectMapper objectMapper = new ObjectMapper();
        objectMapper.configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false);
        return new MappingJackson2HttpMessageConverter(objectMapper);
    }

}

WebSocket client: WebsocketClient.java

import java.lang.reflect.Type;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Base64;
import java.util.List;
import java.util.concurrent.CountDownLatch;

import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.ResponseEntity;
import org.springframework.messaging.converter.MappingJackson2MessageConverter;
import org.springframework.messaging.simp.stomp.StompCommand;
import org.springframework.messaging.simp.stomp.StompFrameHandler;
import org.springframework.messaging.simp.stomp.StompHeaders;
import org.springframework.messaging.simp.stomp.StompSession;
import org.springframework.messaging.simp.stomp.StompSessionHandler;
import org.springframework.messaging.simp.stomp.StompSessionHandlerAdapter;
import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler;
import org.springframework.web.client.RestTemplate;
import org.springframework.web.socket.client.WebSocketClient;
import org.springframework.web.socket.client.standard.StandardWebSocketClient;
import org.springframework.web.socket.messaging.WebSocketStompClient;
import org.springframework.web.socket.sockjs.client.SockJsClient;
import org.springframework.web.socket.sockjs.client.Transport;
import org.springframework.web.socket.sockjs.client.WebSocketTransport;

//@JsonIgnoreProperties(ignoreUnknown = true)
/**
 * Command-line test client: calls the REST login endpoint with HTTP Basic credentials and
 * then opens a STOMP session over SockJS/WebSocket.
 */
public class WebsocketClient {

    /**
     * Runs the two steps in order: REST login, then the STOMP connection.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        // 1. login to rest service
        authToRest();
        // 2. establish websocket connection
        openConnection();
    }

    /**
     * Builds HTTP headers carrying an {@code Authorization: Basic ...} entry for
     * {@code admin:admin}.
     *
     * @return headers containing the Basic credentials
     */
    private static HttpHeaders getHeaders() {
        String plainCredentials = "admin:admin";
        // Explicit charset: the encoded header must not depend on the platform default.
        String base64Credentials = Base64.getEncoder()
                .encodeToString(plainCredentials.getBytes(StandardCharsets.UTF_8));

        HttpHeaders headers = new HttpHeaders();
        headers.add("Authorization", "Basic " + base64Credentials);
        return headers;
    }

    /**
     * Sends an authenticated {@code GET} to the login endpoint and prints the response body.
     */
    public static void authToRest() {
        RestTemplate restTemplate = new RestTemplate();
        HttpEntity<String> request = new HttpEntity<>(getHeaders());
        ResponseEntity<String> response = restTemplate.exchange(
                "http://localhost:8082/dm/login", HttpMethod.GET, request, String.class);
        System.out.println(response.getBody());
    }

    /**
     * Opens the STOMP session against {@code ws://localhost:8082/dm-ws} and then blocks the
     * calling thread until it is interrupted.
     *
     * <p>Note that {@code connect} is called with only the URL and the session handler, so
     * no {@code Authorization} header is attached to this connection attempt; only
     * {@link #authToRest()} sends credentials.
     */
    public static void openConnection() {
        List<Transport> transports = new ArrayList<>(1);
        transports.add(new WebSocketTransport(new StandardWebSocketClient()));
        WebSocketClient transport = new SockJsClient(transports);
        WebSocketStompClient stompClient = new WebSocketStompClient(transport);

        stompClient.setMessageConverter(new MappingJackson2MessageConverter());
        ThreadPoolTaskScheduler taskScheduler = new ThreadPoolTaskScheduler();
        taskScheduler.afterPropertiesSet();
        stompClient.setTaskScheduler(taskScheduler); // for heartbeats
        StompSessionHandler myHandler = new MyStompHandler();
        String url = "ws://localhost:8082/dm-ws";
        stompClient.connect(url, myHandler);

        // Block the thread: the latch is never counted down, so await() only returns
        // when the thread is interrupted.
        CountDownLatch latch = new CountDownLatch(1);
        try {
            latch.await();
        } catch (InterruptedException e) {
            // Restore the interrupt flag so callers further up the stack can observe it.
            Thread.currentThread().interrupt();
            e.printStackTrace();
        }
    }

    /**
     * Session handler that subscribes to {@code /topic/response} once connected and prints
     * received payloads and errors to standard output.
     */
    private static class MyStompHandler extends StompSessionHandlerAdapter {

        @Override
        public void afterConnected(StompSession session, StompHeaders connectedHeaders) {
            session.subscribe("/topic/response", new StompFrameHandler() {
                @Override
                public Type getPayloadType(StompHeaders headers) {
                    return Object.class;
                }

                @Override
                public void handleFrame(StompHeaders headers, Object payload) {
                    System.out.println(payload);
                }
            });
        }

        @Override
        public void handleException(StompSession session, StompCommand command, StompHeaders headers,
                byte[] payload, Throwable exception) {
            System.out.println(exception.getMessage());
        }

        @Override
        public void handleTransportError(StompSession session, Throwable exception) {
            exception.printStackTrace();
            System.out.println("transport error.");
        }
    }
}

However, although I get the REST response, the client cannot establish a connection to the WebSocket server. This is the error output:

14:31:35.368 [main] DEBUG org.springframework.web.client.RestTemplate - GET request for "http://localhost:8082/dm/login" resulted in 200 (null)
14:31:35.369 [main] DEBUG org.springframework.web.client.RestTemplate - Reading [java.lang.String] as "text/plain;charset=UTF-8" using [org.springframework.http.converter.StringHttpMessageConverter@4b553d26]
Login success to WssBroker...
14:31:35.415 [main] INFO org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler - Initializing ExecutorService 
14:31:35.569 [main] DEBUG org.springframework.web.socket.sockjs.client.RestTemplateXhrTransport - Executing SockJS Info request, url=http://localhost:8082/dm-ws/info
14:31:35.569 [main] DEBUG org.springframework.web.client.RestTemplate - Created GET request for "http://localhost:8082/dm-ws/info"
14:31:35.574 [main] DEBUG org.springframework.web.client.RestTemplate - GET request for "http://localhost:8082/dm-ws/info" resulted in 401 (null); invoking error handler
14:31:35.578 [main] ERROR org.springframework.web.socket.sockjs.client.SockJsClient - Initial SockJS "Info" request to server failed, url=ws://localhost:8082/dm-ws
org.springframework.web.client.HttpClientErrorException: 401 null
    at org.springframework.web.client.DefaultResponseErrorHandler.handleError(DefaultResponseErrorHandler.java:91)
    at org.springframework.web.client.RestTemplate.handleResponse(RestTemplate.java:667)
    at org.springframework.web.client.RestTemplate.doExecute(RestTemplate.java:620)
    at org.springframework.web.client.RestTemplate.execute(RestTemplate.java:595)
    at org.springframework.web.socket.sockjs.client.RestTemplateXhrTransport.executeInfoRequestInternal(RestTemplateXhrTransport.java:138)
    at org.springframework.web.socket.sockjs.client.AbstractXhrTransport.executeInfoRequest(AbstractXhrTransport.java:155)
    at org.springframework.web.socket.sockjs.client.SockJsClient.getServerInfo(SockJsClient.java:286)
    at org.springframework.web.socket.sockjs.client.SockJsClient.doHandshake(SockJsClient.java:254)
    at org.springframework.web.socket.messaging.WebSocketStompClient.connect(WebSocketStompClient.java:274)
    at org.springframework.web.socket.messaging.WebSocketStompClient.connect(WebSocketStompClient.java:255)
    at org.springframework.web.socket.messaging.WebSocketStompClient.connect(WebSocketStompClient.java:235)
    at org.springframework.web.socket.messaging.WebSocketStompClient.connect(WebSocketStompClient.java:219)
    at com.hspi.dm.console.message.WebsocketClient.openConnection(WebsocketClient.java:66)
    at com.hspi.dm.console.message.WebsocketClient.main(WebsocketClient.java:35)
14:31:35.583 [main] DEBUG org.springframework.messaging.simp.stomp.DefaultStompSession - Failed to connect session id=d8fda5d4-ba5a-7d22-f517-74e939096bfa
org.springframework.web.client.HttpClientErrorException: 401 null
    at org.springframework.web.client.DefaultResponseErrorHandler.handleError(DefaultResponseErrorHandler.java:91)
    at org.springframework.web.client.RestTemplate.handleResponse(RestTemplate.java:667)
    at org.springframework.web.client.RestTemplate.doExecute(RestTemplate.java:620)
    at org.springframework.web.client.RestTemplate.execute(RestTemplate.java:595)
    at org.springframework.web.socket.sockjs.client.RestTemplateXhrTransport.executeInfoRequestInternal(RestTemplateXhrTransport.java:138)
    at org.springframework.web.socket.sockjs.client.AbstractXhrTransport.executeInfoRequest(AbstractXhrTransport.java:155)
    at org.springframework.web.socket.sockjs.client.SockJsClient.getServerInfo(SockJsClient.java:286)
    at org.springframework.web.socket.sockjs.client.SockJsClient.doHandshake(SockJsClient.java:254)
    at org.springframework.web.socket.messaging.WebSocketStompClient.connect(WebSocketStompClient.java:274)
    at org.springframework.web.socket.messaging.WebSocketStompClient.connect(WebSocketStompClient.java:255)
    at org.springframework.web.socket.messaging.WebSocketStompClient.connect(WebSocketStompClient.java:235)
    at org.springframework.web.socket.messaging.WebSocketStompClient.connect(WebSocketStompClient.java:219)
    at com.hspi.dm.console.message.WebsocketClient.openConnection(WebsocketClient.java:66)
    at com.hspi.dm.console.message.WebsocketClient.main(WebsocketClient.java:35)

Could someone help me? Thanks.

Update

From this post, the problem seems to be that I didn't secure the WebSocket endpoint. I tried those two solutions, but they didn't work; maybe I applied them the wrong way. I am looking for the correct way to secure a WebSocket endpoint.

Dave Pateral
  • 1,415
  • 1
  • 14
  • 21

1 Answer

8

Well, I misunderstood the statement in the Spring documentation that "WebSockets reuse the same authentication information that is found in the HTTP request when the WebSocket connection was made".

I should have authenticated the HTTP request that is upgraded to a WebSocket connection, rather than the REST service. In addition, something needs to change in the WebSocket configuration.

1. Disable CSRF within WebSocket

Add sameOriginDisabled() to WssBrokerConfig.

@Configuration
public class WssBrokerConfig extends AbstractSecurityWebSocketMessageBrokerConfigurer {

    ...

    /**
     * Disables the same-origin (CSRF) enforcement for WebSocket messages.
     *
     * @return {@code true}, so the check is switched off
     */
    @Override
    protected boolean sameOriginDisabled() {
        return true;
    }
}

2. Connect with WebSocketHttpHeaders

Construct a WebSocketHttpHeaders instance and add the user credentials to it before calling connect(). The username and password must be Base64-encoded, as HTTP Basic authentication requires (Base64 is an encoding, not encryption).

// HTTP Basic credentials in "user:password" form, Base64-encoded with an explicit charset
// so the result does not depend on the platform default encoding.
String plainCredentials = "admin:admin";
String base64Credentials = Base64.getEncoder()
        .encodeToString(plainCredentials.getBytes(StandardCharsets.UTF_8));

// Handshake headers carrying the credentials.
final WebSocketHttpHeaders headers = new WebSocketHttpHeaders();
headers.add("Authorization", "Basic " + base64Credentials);

// Pass the headers to connect() so they accompany the connection request.
stompClient.connect(url, headers, myHandler);
Dave Pateral
  • 1,415
  • 1
  • 14
  • 21