Administrator
发布于 2023-01-03 / 89 阅读
0
0

Use WebSocket In SpringBoot

Native websocket

pom.xml

<?xml version="1.0" encoding="UTF-8"?>
<!-- Maven build descriptor for the "native-websocket" demo (JSR-356 style endpoints on Spring Boot). -->
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>
    <!-- The Spring Boot parent POM; the dependencies below declare no version of their own. -->
    <parent>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-parent</artifactId>
        <version>2.5.14</version>
        <relativePath/> <!-- lookup parent from repository -->
    </parent>
    <!-- Coordinates of this demo project. -->
    <groupId>com.wp</groupId>
    <artifactId>native-websocket</artifactId>
    <version>0.0.1-SNAPSHOT</version>
    <name>native-websocket</name>
    <description>native-websocket</description>
    <properties>
        <!-- Java language level used to compile the project. -->
        <java.version>1.8</java.version>
    </properties>
    <dependencies>

        <!-- WebSocket support (provides ServerEndpointExporter used by WebSocketConfiguration). -->
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-websocket</artifactId>
        </dependency>
        <!-- Spring MVC, needed for the REST controller that triggers server-side pushes. -->
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-web</artifactId>
        </dependency>

        <!-- Test-only dependencies. -->
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-test</artifactId>
            <scope>test</scope>
        </dependency>
    </dependencies>

    <build>
        <plugins>
            <!-- Spring Boot's Maven plugin. -->
            <plugin>
                <groupId>org.springframework.boot</groupId>
                <artifactId>spring-boot-maven-plugin</artifactId>
            </plugin>
        </plugins>
    </build>

</project>


application.yml

server:
  # Port the application listens on; both the ws:// and http:// test URLs of this demo use it.
  port: 8082


WebSocketConfiguration

package com.wp.nativewebsocket.config;

import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.socket.server.standard.ServerEndpointExporter;

/**
 * Spring configuration for the annotation-based (javax.websocket) endpoints of this demo.
 */
@Configuration
public class WebSocketConfiguration {

    /**
     * Registers a {@link ServerEndpointExporter} bean. This bean automatically registers
     * the WebSocket endpoints that are declared with the {@code @ServerEndpoint} annotation.
     *
     * @return a new exporter instance
     */
    @Bean
    public ServerEndpointExporter serverEndpointExporter() {
        final ServerEndpointExporter exporter = new ServerEndpointExporter();
        return exporter;
    }
}



UserEndpoint

package com.wp.nativewebsocket.endpoint;

import org.springframework.stereotype.Component;

import javax.websocket.OnClose;
import javax.websocket.OnMessage;
import javax.websocket.OnOpen;
import javax.websocket.Session;
import javax.websocket.server.PathParam;
import javax.websocket.server.ServerEndpoint;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArraySet;

/**
 * WebSocket endpoint mapped to {@code /myws/{userId}}.
 *
 * <p>All connection state is kept in static collections, so it is shared by every
 * instance of this class. The lifecycle callbacks and the {@code send*} methods may be
 * invoked from different threads (for example an HTTP controller calling
 * {@link #sendAllMessage(String)}), which is why only concurrent collections are used.
 */
@Component
@ServerEndpoint("/myws/{userId}")
public class UserEndpoint {

    /**
     * Every registered session; thread-safe, unordered. Used for broadcasting.
     */
    private static final CopyOnWriteArraySet<Session> SESSIONS = new CopyOnWriteArraySet<>();

    /**
     * Sessions keyed by the {@code userId} path parameter, used for targeted messages.
     * A concurrent map is required because entries are added, removed and looked up
     * from different threads.
     */
    private static final Map<String, Session> SESSION_POOL = new ConcurrentHashMap<>();

    /**
     * Registers a newly opened session for broadcasting and, when a user id is
     * available, for targeted delivery.
     *
     * @param session the session that was just opened
     * @param userId  value of the {@code userId} path parameter
     */
    @OnOpen
    public void onOpen(Session session, @PathParam(value = "userId") String userId) {
        try {
            SESSIONS.add(session);
            // ConcurrentHashMap rejects null keys; a session without a user id is broadcast-only.
            if (userId != null) {
                SESSION_POOL.put(userId, session);
            }
            System.out.println("【WebSocket消息】有新的连接,总数为:" + SESSIONS.size());
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Unregisters a closed session from both collections.
     *
     * @param session the session that was closed
     */
    @OnClose
    public void onClose(Session session) {
        try {
            SESSIONS.remove(session);
            // Also drop the user-id mapping, otherwise closed sessions pile up in the pool.
            // remove(key, value) only deletes the entry while it still points at THIS session,
            // so a newer session registered under the same user id is left untouched.
            for (Map.Entry<String, Session> entry : SESSION_POOL.entrySet()) {
                if (entry.getValue() == session) {
                    SESSION_POOL.remove(entry.getKey(), session);
                }
            }
            System.out.println("【WebSocket消息】连接断开,总数为:" + SESSIONS.size());
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Logs a text message received from a client.
     *
     * @param message the received text
     */
    @OnMessage
    public void onMessage(String message) {
        System.out.println("【WebSocket消息】收到客户端消息:" + message);
    }

    /**
     * Broadcasts a message to every open session.
     *
     * @param message the text to send
     */
    public void sendAllMessage(String message) {
        System.out.println("【WebSocket消息】广播消息:" + message);
        for (Session session : SESSIONS) {
            try {
                if (session.isOpen()) {
                    session.getAsyncRemote().sendText(message);
                }
            } catch (Exception e) {
                // Best effort: a failure on one session must not stop the broadcast.
                e.printStackTrace();
            }
        }
    }

    /**
     * Sends a message to a single user. Does nothing when the user id is {@code null},
     * unknown, or mapped to a session that is no longer open.
     *
     * @param userId  the user id
     * @param message the text to send
     */
    public void sendOneMessage(String userId, String message) {
        if (userId == null) {
            return;
        }
        Session session = SESSION_POOL.get(userId);
        if (session == null || !session.isOpen()) {
            return;
        }
        try {
            synchronized (session) {
                System.out.println("【WebSocket消息】单点消息:" + message);
                session.getAsyncRemote().sendText(message);
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Sends the same message to several users, one targeted send per user id.
     * A {@code null} array is treated as empty.
     *
     * @param userIds the user ids
     * @param message the text to send
     */
    public void sendMoreMessage(String[] userIds, String message) {
        if (userIds == null) {
            return;
        }
        for (String userId : userIds) {
            sendOneMessage(userId, message);
        }
    }
}



TestController

package com.wp.nativewebsocket.controller;

import com.wp.nativewebsocket.endpoint.UserEndpoint;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;

/**
 * REST controller under {@code /test} used to push a message from the server
 * to the connected WebSocket clients.
 */
@RestController
@RequestMapping("/test")
public class TestController {

    @Autowired
    private UserEndpoint userEndpoint;

    /**
     * Broadcasts the given text via {@code UserEndpoint.sendAllMessage}.
     *
     * @param msg value of the {@code msg} request parameter
     * @return the literal string {@code "OK"}
     */
    @RequestMapping("sendMsg")
    public String sendMsgToClient(@RequestParam("msg") String msg) {
        final String reply = "OK";
        userEndpoint.sendAllMessage(msg);
        return reply;
    }

}


测试,连接、关闭、客户端发送消息给服务端
http://coolaf.com/zh/tool/chattest
输入: ws://127.0.0.1:8082/myws/3

测试,服务端,发送消息,给客户端
http://localhost:8082/test/sendMsg?msg=bbb

注意,上面的几个注解所在的包都是 javax.websocket。它们并不是 Spring 提供的,也不是 JDK 自带的,而是 Java WebSocket 标准 API(JSR 356)的一部分,由 Servlet 容器(例如内嵌的 Tomcat)提供实现。所以,这种方式是原生的 WebSocket 写法。

这里的 @ServerEndpoint 类似于我们熟悉的 @RestController + @RequestMapping 注解:它把 UserEndpoint 这个类标记为 WebSocket 中的一个 endpoint,并声明了它的访问路径。

Spring Websocket

1.首先,我们需要,自定义一个处理器

package com.example.springwebsocket.handler;

import com.example.springwebsocket.service.WpWebSocketService;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.socket.*;

/**
 * {@link WebSocketHandler} that forwards every lifecycle event to
 * {@link WpWebSocketService}.
 */
public class WpWebSocketHandler implements WebSocketHandler {

    @Autowired
    private WpWebSocketService wpWebSocketService;

    /**
     * Forwards a newly established connection to the service.
     */
    @Override
    public void afterConnectionEstablished(WebSocketSession session) throws Exception {
        wpWebSocketService.handleOpen(session);
    }

    /**
     * Forwards text messages to the service; any other message type is ignored.
     */
    @Override
    public void handleMessage(WebSocketSession session, WebSocketMessage<?> message) throws Exception {
        if (!(message instanceof TextMessage)) {
            return;
        }
        String payload = ((TextMessage) message).getPayload();
        wpWebSocketService.handleMessage(session, payload);
    }

    /**
     * Forwards a transport error to the service.
     */
    @Override
    public void handleTransportError(WebSocketSession session, Throwable exception) throws Exception {
        wpWebSocketService.handleError(session, exception);
    }

    /**
     * Forwards a closed connection to the service; the close status is not used.
     */
    @Override
    public void afterConnectionClosed(WebSocketSession session, CloseStatus closeStatus) throws Exception {
        wpWebSocketService.handleClose(session);
    }

    /**
     * Whether partial messages are supported.
     *
     * @return always {@code false}
     */
    @Override
    public boolean supportsPartialMessages() {
        return false;
    }
}




  2. 接下来,我们需要创建一个拦截器
package com.example.springwebsocket.intercepter;

import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.http.server.ServletServerHttpRequest;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.server.HandshakeInterceptor;

import java.util.Map;

/**
 * Handshake interceptor that copies the {@code uid} request parameter into the
 * handshake attributes so that later code can tell clients apart.
 */
public class WpSocketInterceptor implements HandshakeInterceptor {

    /**
     * Reads the {@code uid} request parameter and, when present, stores it under the
     * {@code "uid"} key of {@code attributes}.
     *
     * <p>A missing parameter is deliberately NOT stored as a {@code null} value:
     * code that checks {@code attributes.containsKey("uid")} and then dereferences
     * the value would otherwise fail with a {@link NullPointerException}.
     *
     * @return {@code true} for servlet-based requests (handshake proceeds),
     *         {@code false} for any other request type (handshake is aborted)
     */
    @Override
    public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Map<String, Object> attributes) throws Exception {
        if (!(request instanceof ServletServerHttpRequest)) {
            return false;
        }
        ServletServerHttpRequest servletServerHttpRequest = (ServletServerHttpRequest) request;
        // Simulated user (normally the user would be resolved from a JWT token).
        String userId = servletServerHttpRequest.getServletRequest().getParameter("uid");
        // TODO check that the user actually exists

        // Once uid is in the attributes it can be read back from the session,
        // which is how multiple clients are distinguished.
        if (userId != null) {
            attributes.put("uid", userId);
        }
        return true;
    }

    /**
     * No post-handshake work is needed.
     */
    @Override
    public void afterHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Exception exception) {
        // intentionally empty
    }
}

  3. 接着,我们需要将处理器、拦截器与对应的 url 路径绑定起来
package com.example.springwebsocket.config;

import com.example.springwebsocket.handler.WpWebSocketHandler;
import com.example.springwebsocket.intercepter.WpSocketInterceptor;
import com.example.springwebsocket.service.WpWebSocketService;
import com.example.springwebsocket.service.WpWebSocketServiceImpl;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.socket.config.annotation.EnableWebSocket;
import org.springframework.web.socket.config.annotation.WebSocketConfigurer;
import org.springframework.web.socket.config.annotation.WebSocketHandlerRegistry;

/**
 * Enables Spring WebSocket support and binds the handler and the handshake
 * interceptor to the {@code myws/message} path.
 */
@Configuration
@EnableWebSocket
public class WpWebSocketConfiguration implements WebSocketConfigurer {

    /**
     * @return the service that holds the sessions and sends messages
     */
    @Bean
    public WpWebSocketService webSocket() {
        return new WpWebSocketServiceImpl();
    }

    /**
     * @return the handler that receives the WebSocket lifecycle callbacks
     */
    @Bean
    public WpWebSocketHandler wpWebSocketHandler() {
        return new WpWebSocketHandler();
    }

    /**
     * @return the interceptor that runs around the handshake
     */
    @Bean
    public WpSocketInterceptor wpSocketInterceptor() {
        return new WpSocketInterceptor();
    }

    /**
     * Registers the handler with its interceptor and allows every origin ("*").
     */
    @Override
    public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
        WpWebSocketHandler handler = wpWebSocketHandler();
        WpSocketInterceptor interceptor = wpSocketInterceptor();
        registry.addHandler(handler, "myws/message")
                .addInterceptors(interceptor)
                .setAllowedOrigins("*");
    }
}



  • @EnableWebSocket:开启WebSocket功能
  • addHandler:添加处理器
  • addInterceptors:添加拦截器
  • setAllowedOrigins:设置允许跨域(允许所有请求来源)


  4. 接下来,我们就需要编写业务类
package com.example.springwebsocket.service;

import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;

import java.io.IOException;
import java.util.Set;

public interface WpWebSocketService {
    /**
     * Callback invoked when a session starts.
     *
     * @param session the session
     */
    void handleOpen(WebSocketSession session);

    /**
     * Callback invoked when a session ends.
     *
     * @param session the session
     */
    void handleClose(WebSocketSession session);

    /**
     * Handles an incoming message.
     *
     * @param session the session
     * @param message the received message
     */
    void handleMessage(WebSocketSession session, String message);

    /**
     * Sends a message.
     *
     * @param session the current session
     * @param message the message to send
     * @throws IOException on an I/O error while sending
     */
    void sendMessage(WebSocketSession session, String message) throws IOException;

    /**
     * Sends a message.
     *
     * @param userId  the user id
     * @param message the message to send
     * @throws IOException on an I/O error while sending
     */
    void sendMessage(String userId, TextMessage message) throws IOException;

    /**
     * Sends a message.
     *
     * @param userId  the user id
     * @param message the message to send
     * @throws IOException on an I/O error while sending
     */
    void sendMessage(String userId, String message) throws IOException;

    /**
     * Sends a message.
     *
     * @param session the current session
     * @param message the message to send
     * @throws IOException on an I/O error while sending
     */
    void sendMessage(WebSocketSession session, TextMessage message) throws IOException;

    /**
     * Broadcasts a message.
     *
     * @param message the message as a string
     * @throws IOException on error
     */
    void broadCast(String message) throws IOException;

    /**
     * Broadcasts a message.
     *
     * @param message the text message
     * @throws IOException on error
     */
    void broadCast(TextMessage message) throws IOException;

    /**
     * Handles a session error.
     *
     * @param session the session
     * @param error   the error
     */
    void handleError(WebSocketSession session, Throwable error);

    /**
     * Returns all WebSocket sessions.
     *
     * @return all WebSocket sessions
     */
    Set<WebSocketSession> getSessions();

    /**
     * Returns the current number of connections.
     *
     * @return the connection count
     */
    int getConnectionCount();
}


package com.example.springwebsocket.service;

import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;

import java.io.IOException;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.CopyOnWriteArraySet;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * In-memory {@link WpWebSocketService}: keeps the open sessions in a thread-safe set
 * and offers targeted and broadcast sending. Users are identified by the
 * {@code "uid"} entry of the session attributes.
 */
public class WpWebSocketServiceImpl implements WpWebSocketService {
    /**
     * Number of online connections (thread-safe).
     */
    private final AtomicInteger connectionCount = new AtomicInteger(0);

    /**
     * Thread-safe, unordered set holding the sessions.
     */
    private final CopyOnWriteArraySet<WebSocketSession> sessions = new CopyOnWriteArraySet<>();

    /**
     * Registers the session. The counter only moves when the session was not yet
     * registered, so it cannot drift away from the content of the session set.
     */
    @Override
    public void handleOpen(WebSocketSession session) {
        if (sessions.add(session)) {
            connectionCount.incrementAndGet();
        }
        System.out.println("a new connection opened,current online count:" + connectionCount.get());
    }

    /**
     * Unregisters the session. The counter only moves when the session was actually
     * registered, so an unexpected extra close callback cannot push it below the real value.
     */
    @Override
    public void handleClose(WebSocketSession session) {
        if (sessions.remove(session)) {
            connectionCount.decrementAndGet();
        }
        System.out.println("a new connection closed,current online count: " + connectionCount.get());
    }

    /**
     * Logs the received text; the message is otherwise discarded.
     */
    @Override
    public void handleMessage(WebSocketSession session, String message) {
        System.out.println("received a message:" + message);
    }

    @Override
    public void sendMessage(WebSocketSession session, String message) throws IOException {
        this.sendMessage(session, new TextMessage(message));
    }

    /**
     * Sends the message to the first open session whose {@code "uid"} attribute equals
     * {@code userId}; does nothing when no such session exists. Sessions whose
     * {@code "uid"} attribute is absent or {@code null} never match.
     */
    @Override
    public void sendMessage(String userId, TextMessage message) throws IOException {
        for (WebSocketSession session : sessions) {
            if (!session.isOpen()) {
                continue;
            }
            // Null-safe comparison: the attribute may be missing or explicitly null.
            Object uid = session.getAttributes().get("uid");
            if (uid != null && uid.equals(userId)) {
                session.sendMessage(message);
                return;
            }
        }
    }

    @Override
    public void sendMessage(String userId, String message) throws IOException {
        this.sendMessage(userId, new TextMessage(message));
    }

    @Override
    public void sendMessage(WebSocketSession session, TextMessage message) throws IOException {
        session.sendMessage(message);
    }

    @Override
    public void broadCast(String message) throws IOException {
        this.broadCast(new TextMessage(message));
    }

    /**
     * Sends the message to every open session. A failure on one session does not stop
     * delivery to the others; after all sessions were tried, the first
     * {@link IOException} is rethrown with any later ones attached as suppressed.
     *
     * @throws IOException if sending failed for at least one session
     */
    @Override
    public void broadCast(TextMessage message) throws IOException {
        IOException firstFailure = null;
        for (WebSocketSession session : sessions) {
            if (!session.isOpen()) {
                continue;
            }
            try {
                this.sendMessage(session, message);
            } catch (IOException e) {
                if (firstFailure == null) {
                    firstFailure = e;
                } else {
                    firstFailure.addSuppressed(e);
                }
            }
        }
        if (firstFailure != null) {
            throw firstFailure;
        }
    }

    /**
     * Prints the error message and the session id to standard output, and the error itself
     * to standard error.
     */
    @Override
    public void handleError(WebSocketSession session, Throwable error) {
        System.out.println("websocket error:" + error.getMessage() + " ,session id:" + session.getId());
        System.err.println(error);
    }

    /**
     * Returns the live internal session set (not a copy).
     */
    @Override
    public Set<WebSocketSession> getSessions() {
        return sessions;
    }

    @Override
    public int getConnectionCount() {
        return connectionCount.get();
    }
}

  5. 创建一个 controller,用于服务端向客户端发送消息
package com.example.springwebsocket.controller;

import com.example.springwebsocket.service.WpWebSocketService;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;

import java.io.IOException;

/**
 * REST controller under {@code /test} used to push a message from the server
 * to one WebSocket client.
 */
@RestController
@RequestMapping("/test")
public class TestController {

    @Autowired
    private WpWebSocketService wpWebSocketService;

    /**
     * Sends the given text to the user with the fixed id {@code "5"}.
     *
     * @param msg value of the {@code msg} request parameter
     * @return the literal string {@code "OK"}
     * @throws IOException if the service fails to send the message
     */
    @RequestMapping("sendMsg")
    public String sendMsgToClient(@RequestParam("msg") String msg) throws IOException {
        final String targetUserId = "5";
        wpWebSocketService.sendMessage(targetUserId, msg);
        return "OK";
    }
}





测试,连接、关闭、客户端发送消息给服务端
http://coolaf.com/zh/tool/chattest
输入: ws://127.0.0.1:8083/myws/message?uid=5

测试,服务端,发送消息,给客户端
http://localhost:8083/test/sendMsg?msg=gogogo

拦截器中的uid,设置到attributes中

这一节我们来说明:在拦截器中把 uid 的值设置到 attributes 之后,为什么能够在 WebSocketSession 中获取到这个 uid 的值?


我们看下,下面的代码,就懂了:

// org.springframework.web.socket.server.support.WebSocketHttpRequestHandler#handleRequest


// Quoted framework method: runs the handshake interceptors around the actual handshake.
@Override
	public void handleRequest(HttpServletRequest servletRequest, HttpServletResponse servletResponse)
			throws ServletException, IOException {

		ServerHttpRequest request = new ServletServerHttpRequest(servletRequest);
		ServerHttpResponse response = new ServletServerHttpResponse(servletResponse);

		HandshakeInterceptorChain chain = new HandshakeInterceptorChain(this.interceptors, this.wsHandler);
		HandshakeFailureException failure = null;

		try {
			if (logger.isDebugEnabled()) {
				logger.debug(servletRequest.getMethod() + " " + servletRequest.getRequestURI());
			}
			Map<String, Object> attributes = new HashMap<>();
			
			// This step runs the before-handshake hooks, i.e. our HandshakeInterceptor's beforeHandshake method, which is where we put the uid value into attributes
			if (!chain.applyBeforeHandshake(request, response, attributes)) {
				return;
			}
			// Next comes the handshake itself: according to this article, a StandardWebSocketSession is created from various properties and this attributes map is passed to its constructor (that code is not shown here)
			this.handshakeHandler.doHandshake(request, response, this.wsHandler, attributes);
			chain.applyAfterHandshake(request, response, null);
		}
		catch (HandshakeFailureException ex) {
			failure = ex;
		}
		catch (Exception ex) {
			failure = new HandshakeFailureException("Uncaught failure for request " + request.getURI(), ex);
		}
		finally {
			// On failure the after-handshake hooks still run, then the response is closed and the failure rethrown.
			if (failure != null) {
				chain.applyAfterHandshake(request, response, failure);
				response.close();
				throw failure;
			}
			response.close();
		}
	}

评论