Native websocket
pom.xml
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-parent</artifactId>
<version>2.5.14</version>
<relativePath/> <!-- lookup parent from repository -->
</parent>
<groupId>com.wp</groupId>
<artifactId>native-websocket</artifactId>
<version>0.0.1-SNAPSHOT</version>
<name>native-websocket</name>
<description>native-websocket</description>
<properties>
<java.version>1.8</java.version>
</properties>
<dependencies>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-websocket</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-web</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-test</artifactId>
<scope>test</scope>
</dependency>
</dependencies>
<build>
<plugins>
<plugin>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-maven-plugin</artifactId>
</plugin>
</plugins>
</build>
</project>
application.yml
server:
  port: 8082
WebSocketConfiguration
package com.wp.nativewebsocket.config;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.socket.server.standard.ServerEndpointExporter;
@Configuration
public class WebSocketConfiguration {

    /**
     * Exposes a {@link ServerEndpointExporter} as a Spring bean.
     *
     * <p>This bean automatically registers every WebSocket endpoint that is declared with the
     * {@code @ServerEndpoint} annotation.
     *
     * @return the exporter that registers the annotated endpoints
     */
    @Bean
    public ServerEndpointExporter serverEndpointExporter() {
        final ServerEndpointExporter exporter = new ServerEndpointExporter();
        return exporter;
    }
}
UserEndpoint
package com.wp.nativewebsocket.endpoint;
import org.springframework.stereotype.Component;
import javax.websocket.OnClose;
import javax.websocket.OnMessage;
import javax.websocket.OnOpen;
import javax.websocket.Session;
import javax.websocket.server.PathParam;
import javax.websocket.server.ServerEndpoint;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArraySet;
@Component
@ServerEndpoint("/myws/{userId}")
public class UserEndpoint {

    /**
     * All currently registered sessions (thread-safe, unordered).
     *
     * <p>Static so that every instance of this class sees the same set. That includes the Spring
     * bean the controller calls into.
     */
    private static final CopyOnWriteArraySet<Session> SESSIONS = new CopyOnWriteArraySet<>();

    /**
     * Maps a user id (the {@code {userId}} path segment) to that user's most recent session.
     *
     * <p>The map is written from the connection callbacks and read by whichever thread calls the
     * send methods (for example a controller thread). It therefore has to be a concurrent map; a
     * plain {@code HashMap} is not safe for that. {@link ConcurrentHashMap} rejects {@code null}
     * keys, which is why the lookups below guard against a {@code null} user id.
     */
    private static final Map<String, Session> SESSION_POOL = new ConcurrentHashMap<>();

    /**
     * Registers a newly opened connection.
     *
     * @param session the new WebSocket session
     * @param userId  the {@code {userId}} path segment of the connection URL
     */
    @OnOpen
    public void onOpen(Session session, @PathParam(value = "userId") String userId) {
        try {
            SESSIONS.add(session);
            if (userId != null) {
                // A reconnecting user replaces their previous session in the pool.
                SESSION_POOL.put(userId, session);
            }
            System.out.println("【WebSocket消息】有新的连接,总数为:" + SESSIONS.size());
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Unregisters a closed connection from both the session set and the user pool.
     *
     * @param session the session that was closed
     */
    @OnClose
    public void onClose(Session session) {
        try {
            SESSIONS.remove(session);
            // Without this the pool keeps a reference to the closed session forever. The
            // conditional remove(key, value) leaves the entry alone if the same user has already
            // reconnected with a newer session.
            for (Map.Entry<String, Session> entry : SESSION_POOL.entrySet()) {
                if (session.equals(entry.getValue())) {
                    SESSION_POOL.remove(entry.getKey(), session);
                }
            }
            System.out.println("【WebSocket消息】连接断开,总数为:" + SESSIONS.size());
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Logs a text message received from a client.
     *
     * @param message the received text
     */
    @OnMessage
    public void onMessage(String message) {
        System.out.println("【WebSocket消息】收到客户端消息:" + message);
    }

    /**
     * Broadcasts a message to every open session.
     *
     * <p>A failure on one session is logged and does not stop delivery to the others.
     *
     * @param message the text to send
     */
    public void sendAllMessage(String message) {
        System.out.println("【WebSocket消息】广播消息:" + message);
        for (Session session : SESSIONS) {
            if (!session.isOpen()) {
                continue;
            }
            try {
                sendText(session, message);
            } catch (Exception e) {
                e.printStackTrace();
            }
        }
    }

    /**
     * Sends a message to a single user.
     *
     * <p>Does nothing if the user has no open session or {@code userId} is {@code null}.
     *
     * @param userId  the target user id
     * @param message the text to send
     */
    public void sendOneMessage(String userId, String message) {
        Session session = findOpenSession(userId);
        if (session == null) {
            return;
        }
        try {
            System.out.println("【WebSocket消息】单点消息:" + message);
            sendText(session, message);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Sends the same message to several users, one at a time.
     *
     * @param userIds the target user ids; {@code null} is treated as "no users"
     * @param message the text to send
     */
    public void sendMoreMessage(String[] userIds, String message) {
        if (userIds == null) {
            return;
        }
        for (String userId : userIds) {
            sendOneMessage(userId, message);
        }
    }

    /** Returns the open session registered for {@code userId}, or {@code null} if there is none. */
    private static Session findOpenSession(String userId) {
        if (userId == null) {
            return null;
        }
        Session session = SESSION_POOL.get(userId);
        return (session != null && session.isOpen()) ? session : null;
    }

    /**
     * Writes one text frame, serializing writes per session.
     *
     * <p>Uses the blocking basic remote while holding the session's lock. With the async remote
     * the call returns as soon as the send has started, so a lock around it would not prevent
     * overlapping writes on the same session.
     */
    private static void sendText(Session session, String message) throws IOException {
        synchronized (session) {
            session.getBasicRemote().sendText(message);
        }
    }
}
TestController
package com.wp.nativewebsocket.controller;
import com.wp.nativewebsocket.endpoint.UserEndpoint;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
@RestController
@RequestMapping("/test")
public class TestController {

    /** Endpoint whose broadcast helper pushes text to every connected client. */
    @Autowired
    private UserEndpoint userEndpoint;

    /**
     * Broadcasts the {@code msg} request parameter to every open WebSocket session.
     *
     * @param text value of the {@code msg} request parameter
     * @return always {@code "OK"}
     */
    @RequestMapping("sendMsg")
    public String sendMsgToClient(@RequestParam("msg") String text) {
        final String reply = "OK";
        userEndpoint.sendAllMessage(text);
        return reply;
    }
}
测试,连接、关闭、客户端发送消息给服务端
http://coolaf.com/zh/tool/chattest
输入: ws://127.0.0.1:8082/myws/3
测试,服务端,发送消息,给客户端
http://localhost:8082/test/sendMsg?msg=bbb
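注意,上面用到的几个注解(@ServerEndpoint、@OnOpen、@OnClose、@OnMessage、@PathParam)都位于 javax.websocket 包(及其子包)下。它们并不是 Spring 提供的,也不是 JDK 自带的,而是 Java EE 的 WebSocket 规范(JSR 356)定义的标准 API,由 Servlet 容器(这里是 Spring Boot 内嵌的 Tomcat)实现。所以,这种方式是原生的 WebSocket 写法。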
这里的 @ServerEndpoint 类似于我们熟悉的 @RestController + @RequestMapping 注解,用来把 UserEndpoint 类标记为 WebSocket 中的一个端点(endpoint)。不同的是,容器会为每个 WebSocket 连接单独创建一个 UserEndpoint 实例,因此需要在所有连接之间共享的会话集合必须声明为 static。
Spring Websocket
1.首先,我们需要,自定义一个处理器
package com.example.springwebsocket.handler;
import com.example.springwebsocket.service.WpWebSocketService;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.socket.*;
public class WpWebSocketHandler implements WebSocketHandler {

    /** Business service that every lifecycle callback is delegated to. */
    @Autowired
    private WpWebSocketService wpWebSocketService;

    /** Tells the service that a new connection has been established. */
    @Override
    public void afterConnectionEstablished(WebSocketSession session) throws Exception {
        wpWebSocketService.handleOpen(session);
    }

    /**
     * Forwards text payloads to the service.
     *
     * <p>Messages of any other type are ignored.
     */
    @Override
    public void handleMessage(WebSocketSession session, WebSocketMessage<?> message) throws Exception {
        if (!(message instanceof TextMessage)) {
            return;
        }
        String payload = ((TextMessage) message).getPayload();
        wpWebSocketService.handleMessage(session, payload);
    }

    /** Hands transport-level errors to the service. */
    @Override
    public void handleTransportError(WebSocketSession session, Throwable exception) throws Exception {
        wpWebSocketService.handleError(session, exception);
    }

    /** Tells the service that the connection has been closed; the close status is not used. */
    @Override
    public void afterConnectionClosed(WebSocketSession session, CloseStatus closeStatus) throws Exception {
        wpWebSocketService.handleClose(session);
    }

    /**
     * Reports whether this handler handles partial (fragmented) messages.
     *
     * @return {@code false}: partial messages are not supported by this handler
     */
    @Override
    public boolean supportsPartialMessages() {
        return false;
    }
}
- 接下来,我们需要创建一个拦截器
package com.example.springwebsocket.intercepter;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.http.server.ServletServerHttpRequest;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.server.HandshakeInterceptor;
import java.util.Map;
public class WpSocketInterceptor implements HandshakeInterceptor {

    /**
     * Copies the {@code uid} query parameter of the handshake request into the handshake
     * attributes.
     *
     * <p>The attributes filled here can later be read from the WebSocket session, which lets
     * the application tell multiple clients apart. When the request has no {@code uid}
     * parameter, no {@code "uid"} entry is stored at all, rather than a {@code null} one. A
     * {@code null} entry would let {@code containsKey("uid")} succeed and then break any code
     * that calls methods on the value.
     *
     * @return {@code true} to continue the handshake for servlet requests, {@code false} otherwise
     */
    @Override
    public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Map<String, Object> attributes) throws Exception {
        if (!(request instanceof ServletServerHttpRequest)) {
            return false;
        }
        ServletServerHttpRequest servletServerHttpRequest = (ServletServerHttpRequest) request;
        // Simulated user (a real application would usually derive this from a JWT).
        String userId = servletServerHttpRequest.getServletRequest().getParameter("uid");
        // TODO check that the user exists
        if (userId != null) {
            attributes.put("uid", userId);
        }
        return true;
    }

    /** No post-handshake work is needed. */
    @Override
    public void afterHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Exception exception) {
    }
}
- 接着,我们需要,将处理器 和 拦截器,和对应的url路径,绑定起来
package com.example.springwebsocket.config;
import com.example.springwebsocket.handler.WpWebSocketHandler;
import com.example.springwebsocket.intercepter.WpSocketInterceptor;
import com.example.springwebsocket.service.WpWebSocketService;
import com.example.springwebsocket.service.WpWebSocketServiceImpl;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.socket.config.annotation.EnableWebSocket;
import org.springframework.web.socket.config.annotation.WebSocketConfigurer;
import org.springframework.web.socket.config.annotation.WebSocketHandlerRegistry;
@Configuration
@EnableWebSocket
public class WpWebSocketConfiguration implements WebSocketConfigurer {

    /** Business service that tracks sessions and sends messages. */
    @Bean
    public WpWebSocketService webSocket() {
        return new WpWebSocketServiceImpl();
    }

    /** Handler that delegates WebSocket lifecycle callbacks to the service. */
    @Bean
    public WpWebSocketHandler wpWebSocketHandler() {
        return new WpWebSocketHandler();
    }

    /** Handshake interceptor that copies the {@code uid} request parameter into the session. */
    @Bean
    public WpSocketInterceptor wpSocketInterceptor() {
        return new WpSocketInterceptor();
    }

    /**
     * Binds the handler and the handshake interceptor to the {@code myws/message} path.
     *
     * <p>NOTE(review): {@code setAllowedOrigins("*")} accepts handshakes from any origin. Restrict
     * it to trusted origins outside of local testing.
     */
    @Override
    public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
        // In a @Configuration class these calls return the singleton beans declared above.
        WpWebSocketHandler handler = wpWebSocketHandler();
        WpSocketInterceptor interceptor = wpSocketInterceptor();
        registry.addHandler(handler, "myws/message").addInterceptors(interceptor).setAllowedOrigins("*");
    }
}
- @EnableWebSocket:开启WebSocket功能
- addHandler:添加处理器
- addInterceptors:添加拦截器
- setAllowedOrigins:设置允许跨域的请求来源,这里的 "*" 表示允许所有来源。这只适合本地测试;生产环境应限定为可信的域名,以防止跨站 WebSocket 劫持。
- 接下来,我们就需要编写业务类
package com.example.springwebsocket.service;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;
import java.io.IOException;
import java.util.Set;
public interface WpWebSocketService {

    // ---------------------------------------------------------------- lifecycle callbacks

    /**
     * Handles a newly opened session.
     *
     * @param session the session that was opened
     */
    void handleOpen(WebSocketSession session);

    /**
     * Handles a closed session.
     *
     * @param session the session that was closed
     */
    void handleClose(WebSocketSession session);

    /**
     * Handles an error raised on a session.
     *
     * @param session the session the error belongs to
     * @param error   the error that was raised
     */
    void handleError(WebSocketSession session, Throwable error);

    /**
     * Processes a text message received from a client.
     *
     * @param session the session the message arrived on
     * @param message the received text
     */
    void handleMessage(WebSocketSession session, String message);

    // ---------------------------------------------------------------- sending to one target

    /**
     * Sends a plain string on the given session.
     *
     * @param session the target session
     * @param message the text to send
     * @throws IOException if writing to the session fails
     */
    void sendMessage(WebSocketSession session, String message) throws IOException;

    /**
     * Sends a prepared text message on the given session.
     *
     * @param session the target session
     * @param message the message to send
     * @throws IOException if writing to the session fails
     */
    void sendMessage(WebSocketSession session, TextMessage message) throws IOException;

    /**
     * Sends a plain string to the session of the given user.
     *
     * @param userId  id of the target user
     * @param message the text to send
     * @throws IOException if writing to the session fails
     */
    void sendMessage(String userId, String message) throws IOException;

    /**
     * Sends a prepared text message to the session of the given user.
     *
     * @param userId  id of the target user
     * @param message the message to send
     * @throws IOException if writing to the session fails
     */
    void sendMessage(String userId, TextMessage message) throws IOException;

    // ---------------------------------------------------------------- broadcasting

    /**
     * Broadcasts a plain string to all sessions.
     *
     * @param message the text to broadcast
     * @throws IOException if writing to a session fails
     */
    void broadCast(String message) throws IOException;

    /**
     * Broadcasts a prepared text message to all sessions.
     *
     * @param message the message to broadcast
     * @throws IOException if writing to a session fails
     */
    void broadCast(TextMessage message) throws IOException;

    // ---------------------------------------------------------------- introspection

    /**
     * Returns all WebSocket sessions known to the service.
     *
     * @return the sessions
     */
    Set<WebSocketSession> getSessions();

    /**
     * Returns the current number of connections.
     *
     * @return the connection count
     */
    int getConnectionCount();
}
package com.example.springwebsocket.service;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;
import java.io.IOException;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.CopyOnWriteArraySet;
import java.util.concurrent.atomic.AtomicInteger;
public class WpWebSocketServiceImpl implements WpWebSocketService {

    /** Number of tracked sessions; only changes when {@link #sessions} actually changes. */
    private final AtomicInteger connectionCount = new AtomicInteger(0);

    /** Tracked sessions (thread-safe, unordered). */
    private final CopyOnWriteArraySet<WebSocketSession> sessions = new CopyOnWriteArraySet<>();

    @Override
    public void handleOpen(WebSocketSession session) {
        // Only count a session that was really added, so the counter cannot drift from the set.
        int count = sessions.add(session) ? connectionCount.incrementAndGet() : connectionCount.get();
        System.out.println("a new connection opened,current online count:"+ count);
    }

    @Override
    public void handleClose(WebSocketSession session) {
        // Only count a session that was really removed (guards against duplicate close callbacks).
        int count = sessions.remove(session) ? connectionCount.decrementAndGet() : connectionCount.get();
        System.out.println("a new connection closed,current online count: "+count);
    }

    /**
     * Logs a text message received from a client; the message is otherwise discarded.
     */
    @Override
    public void handleMessage(WebSocketSession session, String message) {
        System.out.println("received a message:"+ message);
    }

    @Override
    public void sendMessage(WebSocketSession session, String message) throws IOException {
        this.sendMessage(session, new TextMessage(message));
    }

    /**
     * Sends to the first open session whose {@code "uid"} attribute equals {@code userId}.
     *
     * <p>Does nothing when no such session exists.
     */
    @Override
    public void sendMessage(String userId, TextMessage message) throws IOException {
        WebSocketSession target = findOpenSession(userId);
        if (target != null) {
            this.sendMessage(target, message);
        }
    }

    @Override
    public void sendMessage(String userId, String message) throws IOException {
        this.sendMessage(userId, new TextMessage(message));
    }

    /**
     * Sends a message on one session.
     *
     * <p>A {@link WebSocketSession} must not be written to concurrently (for example, two HTTP
     * requests pushing to the same client at the same time). Writes are therefore serialized
     * per session here, and every send in this class goes through this method.
     */
    @Override
    public void sendMessage(WebSocketSession session, TextMessage message) throws IOException {
        synchronized (session) {
            session.sendMessage(message);
        }
    }

    @Override
    public void broadCast(String message) throws IOException {
        this.broadCast(new TextMessage(message));
    }

    /**
     * Sends the message to every open session.
     *
     * <p>A failure on one session does not stop delivery to the rest. The first
     * {@link IOException} is rethrown after all sessions have been tried, with any later ones
     * attached as suppressed exceptions.
     */
    @Override
    public void broadCast(TextMessage message) throws IOException {
        IOException failure = null;
        for (WebSocketSession session : sessions) {
            if (!session.isOpen()) {
                continue;
            }
            try {
                this.sendMessage(session, message);
            } catch (IOException e) {
                if (failure == null) {
                    failure = e;
                } else {
                    failure.addSuppressed(e);
                }
            }
        }
        if (failure != null) {
            throw failure;
        }
    }

    @Override
    public void handleError(WebSocketSession session, Throwable error) {
        System.out.println("websocket error:"+error.getMessage()+" ,session id:"+ session.getId());
        System.err.println(error);
    }

    /**
     * Returns the live internal session set (not a copy).
     */
    @Override
    public Set<WebSocketSession> getSessions() {
        return sessions;
    }

    @Override
    public int getConnectionCount() {
        return connectionCount.get();
    }

    /**
     * Returns the first open session whose {@code "uid"} attribute equals {@code userId}, or
     * {@code null} if there is none.
     *
     * <p>The comparison is null-safe: a session whose {@code "uid"} attribute is missing,
     * {@code null}, or not a String simply does not match.
     */
    private WebSocketSession findOpenSession(String userId) {
        if (userId == null) {
            return null;
        }
        for (WebSocketSession session : sessions) {
            if (session.isOpen() && userId.equals(session.getAttributes().get("uid"))) {
                return session;
            }
        }
        return null;
    }
}
- 创建一个controller,用于服务端,向客户端发送消息
package com.example.springwebsocket.controller;
import com.example.springwebsocket.service.WpWebSocketService;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
import java.io.IOException;
@RestController
@RequestMapping("/test")
public class TestController {

    /** Service used to push messages to connected WebSocket clients. */
    @Autowired
    private WpWebSocketService wpWebSocketService;

    /**
     * Sends {@code msg} to the WebSocket client that connected with the given {@code uid}.
     *
     * @param msg text to push to the client
     * @param uid id of the target user. It defaults to {@code "5"}, the id that used to be
     *            hard-coded here, so requests without a {@code uid} parameter behave as before.
     * @return {@code "OK"} whenever the send call returns normally
     * @throws IOException if writing to the session fails
     */
    @RequestMapping("sendMsg")
    public String sendMsgToClient(@RequestParam("msg") String msg,
                                  @RequestParam(value = "uid", defaultValue = "5") String uid) throws IOException {
        wpWebSocketService.sendMessage(uid, msg);
        return "OK";
    }
}
测试,连接、关闭、客户端发送消息给服务端
http://coolaf.com/zh/tool/chattest
输入: ws://127.0.0.1:8083/myws/message?uid=5
测试,服务端,发送消息,给客户端
http://localhost:8083/test/sendMsg?msg=gogogo
拦截器中的uid,设置到attributes中
在这一节中,我们来说明:在拦截器中把 uid 的值放入 attributes 之后,为什么能在 WebSocketSession 中获取到这个 uid 的值?
我们看下,下面的代码,就懂了:
// org.springframework.web.socket.server.support.WebSocketHttpRequestHandler#handleRequest
// Excerpt from Spring Framework, quoted to show how handshake attributes reach the session.
// A new attributes map is created for every request. The interceptor chain fills it before
// the handshake, and the same map is then passed to the handshake handler. If the chain
// rejects the request, the method returns without performing the handshake. Any failure is
// reported to the interceptors' afterHandshake hooks and rethrown as a
// HandshakeFailureException. The response is closed in every case (see the finally block).
@Override
public void handleRequest(HttpServletRequest servletRequest, HttpServletResponse servletResponse)
throws ServletException, IOException {
ServerHttpRequest request = new ServletServerHttpRequest(servletRequest);
ServerHttpResponse response = new ServletServerHttpResponse(servletResponse);
HandshakeInterceptorChain chain = new HandshakeInterceptorChain(this.interceptors, this.wsHandler);
HandshakeFailureException failure = null;
try {
if (logger.isDebugEnabled()) {
logger.debug(servletRequest.getMethod() + " " + servletRequest.getRequestURI());
}
Map<String, Object> attributes = new HashMap<>();
// This step runs the "before handshake" hooks, i.e. our HandshakeInterceptor's beforeHandshake method, which is where we put the uid value into attributes.
if (!chain.applyBeforeHandshake(request, response, attributes)) {
return;
}
// Next, the handshake itself is performed. Among other things it creates a StandardWebSocketSession, and this attributes map is passed to the StandardWebSocketSession constructor.
this.handshakeHandler.doHandshake(request, response, this.wsHandler, attributes);
chain.applyAfterHandshake(request, response, null);
}
catch (HandshakeFailureException ex) {
failure = ex;
}
catch (Exception ex) {
// Wrap any other exception so callers always see a HandshakeFailureException.
failure = new HandshakeFailureException("Uncaught failure for request " + request.getURI(), ex);
}
finally {
if (failure != null) {
chain.applyAfterHandshake(request, response, failure);
response.close();
throw failure;
}
response.close();
}
}