Parcourir la source

优化websocket通讯,加上token认证
新增获取聊天消息

fangzhen il y a 6 mois
Parent
commit
f44246c41d

+ 13 - 7
ruoyi-common/src/main/java/com/ruoyi/common/utils/SecurityUtils.java

@@ -1,17 +1,18 @@
 package com.ruoyi.common.utils;
 
-import java.util.Collection;
-import java.util.List;
-import java.util.stream.Collectors;
-import org.springframework.security.core.Authentication;
-import org.springframework.security.core.context.SecurityContextHolder;
-import org.springframework.security.crypto.bcrypt.BCryptPasswordEncoder;
-import org.springframework.util.PatternMatchUtils;
 import com.ruoyi.common.constant.Constants;
 import com.ruoyi.common.constant.HttpStatus;
 import com.ruoyi.common.core.domain.entity.SysRole;
 import com.ruoyi.common.core.domain.model.LoginUser;
 import com.ruoyi.common.exception.ServiceException;
+import org.springframework.security.core.Authentication;
+import org.springframework.security.core.context.SecurityContextHolder;
+import org.springframework.security.crypto.bcrypt.BCryptPasswordEncoder;
+import org.springframework.util.PatternMatchUtils;
+
+import java.util.Collection;
+import java.util.List;
+import java.util.stream.Collectors;
 
 /**
  * 安全服务工具类
@@ -175,4 +176,9 @@ public class SecurityUtils
                 .anyMatch(x -> Constants.SUPER_ADMIN.equals(x) || PatternMatchUtils.simpleMatch(x, role));
     }
 
+
+    public static void setAuthentication(Authentication authentication) {
+        SecurityContextHolder.getContext().setAuthentication(authentication);
+    }
+
 }

+ 1 - 1
ruoyi-framework/src/main/java/com/ruoyi/framework/config/SecurityConfig.java

@@ -132,7 +132,7 @@ public class SecurityConfig {
                     requests.antMatchers("/login","/phoneLogin", "/register", "/captchaImage","/system/user/profile/updatePwdBySmsCode").permitAll()
                             // 静态资源,可匿名访问
                             .antMatchers(HttpMethod.GET, "/", "/*.html", "/**/*.html", "/**/*.css", "/**/*.js", "/profile/**").permitAll()
-                            .antMatchers("/swagger-ui.html", "/swagger-resources/**", "/webjars/**", "/*/api-docs", "/druid/**", "/websocket/**").permitAll()
+                            .antMatchers("/swagger-ui.html", "/swagger-resources/**", "/webjars/**", "/*/api-docs", "/druid/**").permitAll()
                             // 除上面外的所有请求全部需要鉴权认证
                             .anyRequest().authenticated();
                 })

+ 13 - 10
ruoyi-framework/src/main/java/com/ruoyi/framework/web/service/TokenService.java

@@ -1,14 +1,5 @@
 package com.ruoyi.framework.web.service;
 
-import java.util.HashMap;
-import java.util.Map;
-import java.util.concurrent.TimeUnit;
-import javax.servlet.http.HttpServletRequest;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-import org.springframework.beans.factory.annotation.Autowired;
-import org.springframework.beans.factory.annotation.Value;
-import org.springframework.stereotype.Component;
 import com.ruoyi.common.constant.CacheConstants;
 import com.ruoyi.common.constant.Constants;
 import com.ruoyi.common.core.domain.model.LoginUser;
@@ -22,6 +13,17 @@ import eu.bitwalker.useragentutils.UserAgent;
 import io.jsonwebtoken.Claims;
 import io.jsonwebtoken.Jwts;
 import io.jsonwebtoken.SignatureAlgorithm;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.beans.factory.annotation.Value;
+import org.springframework.stereotype.Component;
+
+import javax.servlet.http.HttpServletRequest;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Optional;
+import java.util.concurrent.TimeUnit;
 
 /**
  * token验证处理
@@ -216,7 +218,8 @@ public class TokenService
      */
     private String getToken(HttpServletRequest request)
     {
-        String token = request.getHeader(header);
+//        String token = request.getHeader(header);
+        String token = Optional.ofNullable(request.getHeader(header)).orElse(request.getParameter(header));
         if (StringUtils.isNotEmpty(token) && token.startsWith(Constants.TOKEN_PREFIX))
         {
             token = token.replace(Constants.TOKEN_PREFIX, "");

+ 58 - 0
ruoyi-framework/src/main/java/com/ruoyi/framework/webSocket/SemaphoreUtils.java

@@ -0,0 +1,58 @@
+package com.ruoyi.framework.webSocket;
+
+import java.util.concurrent.Semaphore;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * 信号量相关处理
+ * 
+ * @author ruoyi
+ */
+public class SemaphoreUtils
+{
+    /**
+     * SemaphoreUtils 日志控制器
+     */
+    private static final Logger LOGGER = LoggerFactory.getLogger(SemaphoreUtils.class);
+
+    /**
+     * 获取信号量
+     * 
+     * @param semaphore
+     * @return
+     */
+    public static boolean tryAcquire(Semaphore semaphore)
+    {
+        boolean flag = false;
+
+        try
+        {
+            flag = semaphore.tryAcquire();
+        }
+        catch (Exception e)
+        {
+            LOGGER.error("获取信号量异常", e);
+        }
+
+        return flag;
+    }
+
+    /**
+     * 释放信号量
+     * 
+     * @param semaphore
+     */
+    public static void release(Semaphore semaphore)
+    {
+
+        try
+        {
+            semaphore.release();
+        }
+        catch (Exception e)
+        {
+            LOGGER.error("释放信号量异常", e);
+        }
+    }
+}

+ 3 - 2
ruoyi-framework/src/main/java/com/ruoyi/framework/webSocket/WebSocketConfig.java

@@ -6,7 +6,8 @@ import org.springframework.web.socket.server.standard.ServerEndpointExporter;
 
 /**
  * websocket 配置
- *
+ * 
+ * @author ruoyi
  */
 @Configuration
 public class WebSocketConfig
@@ -16,4 +17,4 @@ public class WebSocketConfig
     {
         return new ServerEndpointExporter();
     }
-}
+}

+ 205 - 90
ruoyi-framework/src/main/java/com/ruoyi/framework/webSocket/WebSocketServer.java

@@ -1,38 +1,71 @@
 package com.ruoyi.framework.webSocket;
 
 import com.alibaba.fastjson.JSONObject;
+import com.alibaba.fastjson2.JSON;
 import com.ruoyi.common.utils.DateUtils;
+import com.ruoyi.common.utils.SecurityUtils;
 import com.ruoyi.system.domain.CommunityChatMsg;
 import com.ruoyi.system.service.ICommunityChatMsgService;
+import io.netty.util.HashedWheelTimer;
+import io.netty.util.Timeout;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.security.core.Authentication;
 import org.springframework.stereotype.Component;
 
 import javax.websocket.*;
 import javax.websocket.server.PathParam;
 import javax.websocket.server.ServerEndpoint;
 import java.io.IOException;
-import java.util.List;
-import java.util.concurrent.ConcurrentHashMap;
-import java.util.concurrent.CopyOnWriteArrayList;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.concurrent.CopyOnWriteArraySet;
+import java.util.concurrent.Semaphore;
+import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicInteger;
+import java.util.function.Function;
 
+/**
+ * websocket 消息处理
+ *
+ * @author stronger
+ */
 @Component
-@ServerEndpoint("/websocket/{userId}")
+@ServerEndpoint("/websocket/message")
 public class WebSocketServer {
-    private static final Logger log = LoggerFactory.getLogger(WebSocketServer.class);
+    /*========================声明类变量,意在所有实例共享=================================================*/
+    /**
+     * WebSocketServer 日志控制器
+     */
+    private static final Logger LOGGER = LoggerFactory.getLogger(WebSocketServer.class);
 
-    //静态变量,用来记录当前在线连接数。应该把它设计成线程安全的。
-    private static AtomicInteger onlineNum = new AtomicInteger();
+    /**
+     * 默认最多允许同时在线人数100
+     */
+    public static int socketMaxOnlineCount = 100;
 
-    //concurrent包的线程安全Set,用来存放每个客户端对应的WebSocketServer对象。
-    private static ConcurrentHashMap<String, Session> sessionPools = new ConcurrentHashMap<>();
+    private static Semaphore socketSemaphore = new Semaphore(socketMaxOnlineCount);
 
+    HashedWheelTimer timer = new HashedWheelTimer(1, TimeUnit.SECONDS, 8);
+    /**
+     * concurrent包的线程安全Set,用来存放每个客户端对应的MyWebSocket对象。
+     */
+    private static final CopyOnWriteArraySet<WebSocketServer> webSocketSet = new CopyOnWriteArraySet<>();
     /**
-     * 线程安全list,用来存放 在线客户端账号
+     * 连接数
      */
-    public static List<String> userList = new CopyOnWriteArrayList<>();
+    private static final AtomicInteger count = new AtomicInteger();
+
+    /*========================声明实例变量,意在每个实例独享=======================================================*/
+    /**
+     * 与某个客户端的连接会话,需要通过它来给客户端发送数据
+     */
+    private Session session;
+    /**
+     * 用户id
+     */
+    private String sid = "";
 
     public static ICommunityChatMsgService chatMsgService;
 
@@ -42,60 +75,191 @@ public class WebSocketServer {
     }
 
     /**
-     * 连接成功
-     *
-     * @param session
-     * @param userId
+     * 连接建立成功调用的方法
      */
     @OnOpen
-    public void onOpen(Session session, @PathParam(value = "userId") String userId) {
-        sessionPools.put(userId, session);
-        if (!userList.contains(userId)) {
-            addOnlineCount();
-            userList.add(userId);
+    public void onOpen(Session session) throws Exception {
+        // 尝试获取信号量
+        boolean semaphoreFlag = SemaphoreUtils.tryAcquire(socketSemaphore);
+        if (!semaphoreFlag) {
+            // 未获取到信号量
+            LOGGER.error("\n 当前在线人数超过限制数- {}", socketMaxOnlineCount);
+            // 给当前Session 登录用户发送消息
+            WebSocketUsers.sendMessageToUserByText(session, "当前在线人数超过限制数:" + socketMaxOnlineCount);
+            session.close();
+        } else {
+            // 返回此会话的经过身份验证的用户,如果此会话没有经过身份验证的用户,则返回null
+            Authentication authentication = (Authentication) session.getUserPrincipal();
+            SecurityUtils.setAuthentication(authentication);
+            Long userId = SecurityUtils.getUserId();
+            this.session = session;
+            //如果存在就先删除一个,防止重复推送消息
+            for (WebSocketServer webSocket : webSocketSet) {
+                if (webSocket.sid.equals(userId)) {
+                    webSocketSet.remove(webSocket);
+                    count.getAndDecrement();
+                }
+            }
+            count.getAndIncrement();
+            webSocketSet.add(this);
+            this.sid = String.valueOf(userId);
+            LOGGER.info("\n 当前人数 - {}", count);
+            WebSocketUsers.sendMessageToUserByText(session, "连接成功");
         }
-        log.debug("ID为【" + userId + "】的用户加入websocket!当前在线人数为:" + onlineNum);
-        log.debug("当前在线:" + userList);
     }
 
     /**
-     * 关闭连接
-     *
-     * @param userId
+     * 连接关闭时处理
      */
     @OnClose
-    public void onClose(@PathParam(value = "userId") String userId) {
-        sessionPools.remove(userId);
-        if (userList.contains(userId)) {
-            userList.remove(userId);
-            subOnlineCount();
+    public void onClose(Session session) {
+        LOGGER.info("\n 关闭连接 - {}", session);
+        // 移除用户
+        webSocketSet.remove(session);
+        // 获取到信号量则需释放
+        SemaphoreUtils.release(socketSemaphore);
+    }
+
+    /**
+     * 抛出异常时处理
+     */
+    @OnError
+    public void onError(Session session, Throwable exception) throws Exception {
+        if (session.isOpen()) {
+            // 关闭连接
+            session.close();
         }
-        log.debug(userId + "断开webSocket连接!当前人数为" + onlineNum);
+        String sessionId = session.getId();
+        LOGGER.info("\n 连接异常 - {}", sessionId);
+        LOGGER.info("\n 异常信息 - {}", exception);
+        // 移出用户
+        webSocketSet.remove(session);
+        // 获取到信号量则需释放
+        SemaphoreUtils.release(socketSemaphore);
+    }
 
+    /**
+     * 服务器接收到客户端消息时调用的方法
+     */
+    @OnMessage
+    public void onMessage(String message, Session session) {
+        Authentication authentication = (Authentication) session.getUserPrincipal();
+        LOGGER.info("收到来自" + sid + "的信息:" + message);
+        // 实时更新
+        this.refresh(sid, authentication,message);
+        WebSocketUsers.sendMessageToUserByText(session, "成功发送一条消息");
     }
 
     /**
-     * 消息监听
+     * 刷新定时任务,发送信息
+     */
+    private void refresh(String userId, Authentication authentication,String message) {
+//        this.start(5000L, task -> {
+//            // 判断用户是否在线,不在线则不用处理,因为在内部无法关闭该定时任务,所以通过返回值在外部进行判断。
+//            if (WebSocketServer.isConn(userId)) {
+//                // 因为这里是长链接,不会和普通网页一样,每次发送http 请求可以走拦截器【doFilterInternal】续约,所以需要手动续约
+//                SecurityUtils.setAuthentication(authentication);
+//                // 从数据库或者缓存中获取信息,构建自定义的Bean
+//                CommunityChatMsg chatMsg = saveChat(message);
+//                // TODO判断数据是否有更新
+//                // 发送最新数据给前端
+//                WebSocketServer.sendInfo("JSON", chatMsg, String.valueOf(chatMsg.getReceiverId()));
+//                // 设置返回值,判断是否需要继续执行
+//                return true;
+//            }
+//            return false;
+//        });
+        // 判断用户是否在线,不在线则不用处理,因为在内部无法关闭该定时任务,所以通过返回值在外部进行判断。
+        if (WebSocketServer.isConn(userId)) {
+            // 因为这里是长链接,不会和普通网页一样,每次发送http 请求可以走拦截器【doFilterInternal】续约,所以需要手动续约
+            SecurityUtils.setAuthentication(authentication);
+            // 从数据库或者缓存中获取信息,构建自定义的Bean
+            CommunityChatMsg chatMsg = saveChat(message);
+            // TODO判断数据是否有更新
+            // 发送最新数据给前端
+            if (WebSocketServer.isConn(String.valueOf(chatMsg.getReceiverId()))) {
+                WebSocketServer.sendInfo("JSON", chatMsg, String.valueOf(chatMsg.getReceiverId()));
+            }
+        }
+    }
+
+    private void start(long delay, Function<Timeout, Boolean> function) {
+        timer.newTimeout(t -> {
+            // 获取返回值,判断是否执行
+            Boolean result = function.apply(t);
+            if (result) {
+                timer.newTimeout(t.task(), delay, TimeUnit.MILLISECONDS);
+            }
+        }, delay, TimeUnit.MILLISECONDS);
+    }
+
+    /**
+     * 判断是否有链接
      *
-     * @param message
-     * @throws IOException
+     * @return
      */
-    @OnMessage
-    public void onMessage(String message) throws IOException {
+    public static boolean isConn(String sid) {
+        for (WebSocketServer item : webSocketSet) {
+            if (item.sid.equals(sid)) {
+                return true;
+            }
+        }
+        return false;
+    }
+
+    /**
+     * 群发自定义消息
+     * 或者指定用户发送消息
+     */
+    public static void sendInfo(String type, Object data, @PathParam("sid") String sid) {
+        // 遍历WebSocketServer对象集合,如果符合条件就推送
+        for (WebSocketServer item : webSocketSet) {
+            try {
+                //这里可以设定只推送给这个sid的,为null则全部推送
+                if (sid == null) {
+                    item.sendMessage(type, data);
+                } else if (item.sid.equals(sid)) {
+                    item.sendMessage(type, data);
+                }
+            } catch (IOException ignored) {
+            }
+        }
+    }
+
+    /**
+     * 实现服务器主动推送
+     */
+    private void sendMessage(String type, Object data) throws IOException {
+        Map<String, Object> result = new HashMap<>();
+        result.put("type", type);
+        result.put("data", data);
+        this.session.getAsyncRemote().sendText(JSON.toJSONString(result));
+    }
+
+
+    /**
+     * 保存聊天记录
+     *
+     * @param message 消息
+     * @return 聊天记录
+     */
+    public static CommunityChatMsg saveChat(String message) {
         JSONObject jsonObject = JSONObject.parseObject(message);
         Long sendUserId = jsonObject.getLong("sendUserId");
         Long receiveUserId = jsonObject.getLong("receiveUserId");
         String type = jsonObject.getString("type");     //消息分类 chat
         int messageType = jsonObject.getInteger("messageType");     //消息类型
         JSONObject sendText = jsonObject.getJSONObject("sendText");
+        CommunityChatMsg chatMsg = null;
         if (type.equals(MessageType.CHAT.getType())) {
-            log.debug("聊天消息推送");
-            sendToUser(String.valueOf(receiveUserId), JSONObject.toJSONString(jsonObject));
+            LOGGER.debug("聊天消息推送");
+//            sendToUser(String.valueOf(receiveUserId), JSONObject.toJSONString(jsonObject));
 
-            CommunityChatMsg chatMsg = new CommunityChatMsg();
+            chatMsg = new CommunityChatMsg();
             chatMsg.setSenderId(sendUserId);
             chatMsg.setReceiverId(receiveUserId);
             chatMsg.setCreateBy(sendUserId);
+            chatMsg.setRead(false);
             chatMsg.setMessageType(messageType);
             chatMsg.setCreateTime(DateUtils.parseDate(DateUtils.getTime()));
             switch (messageType) {
@@ -129,58 +293,9 @@ public class WebSocketServer {
                     break;
 
             }
-
             //存储历史消息
             chatMsgService.save(chatMsg);
         }
+        return chatMsg;
     }
-
-    /**
-     * 连接错误
-     *
-     * @param session
-     * @param throwable
-     * @throws IOException
-     */
-    @OnError
-    public void onError(Session session, Throwable throwable) throws IOException {
-        log.error("websocket连接错误!");
-        throwable.printStackTrace();
-    }
-
-    /**
-     * 发送消息
-     */
-    public void sendMessage(Session session, String message) throws IOException, EncodeException {
-        if (session != null) {
-            synchronized (session) {
-                session.getBasicRemote().sendText(message);
-            }
-        }
-    }
-
-    /**
-     * 给指定用户发送信息
-     */
-    public void sendToUser(String userId, String message) {
-        Session session = sessionPools.get(userId);
-        try {
-            if (session != null) {
-                sendMessage(session, message);
-            } else {
-                log.debug("推送用户不在线");
-            }
-        } catch (Exception e) {
-            e.printStackTrace();
-        }
-    }
-
-    public static void addOnlineCount() {
-        onlineNum.incrementAndGet();
-    }
-
-    public static void subOnlineCount() {
-        onlineNum.decrementAndGet();
-
-    }
-}
+}

+ 120 - 0
ruoyi-framework/src/main/java/com/ruoyi/framework/webSocket/WebSocketUsers.java

@@ -0,0 +1,120 @@
+package com.ruoyi.framework.webSocket;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import javax.websocket.Session;
+import java.io.IOException;
+import java.util.Collection;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
+
+/**
+ * websocket 客户端用户集
+ *
+ * @author ruoyi
+ */
+public class WebSocketUsers {
+    /**
+     * WebSocketUsers 日志控制器
+     */
+    private static final Logger LOGGER = LoggerFactory.getLogger(WebSocketUsers.class);
+
+    /**
+     * 用户集
+     */
+    private static Map<String, Session> USERS = new ConcurrentHashMap<String, Session>();
+
+
+
+    /**
+     * 存储用户
+     *
+     * @param key     唯一键
+     * @param session 用户信息
+     */
+    public static void put(String key, Session session) {
+        USERS.put(key, session);
+    }
+
+    /**
+     * 移除用户
+     *
+     * @param session 用户信息
+     * @return 移除结果
+     */
+    public static boolean remove(Session session) {
+        String key = null;
+        boolean flag = USERS.containsValue(session);
+        if (flag) {
+            Set<Map.Entry<String, Session>> entries = USERS.entrySet();
+            for (Map.Entry<String, Session> entry : entries) {
+                Session value = entry.getValue();
+                if (value.equals(session)) {
+                    key = entry.getKey();
+                    break;
+                }
+            }
+        } else {
+            return true;
+        }
+        return remove(key);
+    }
+
+    /**
+     * 移出用户
+     *
+     * @param key 键
+     */
+    public static boolean remove(String key) {
+        LOGGER.info("\n 正在移出用户 - {}", key);
+        Session remove = USERS.remove(key);
+        if (remove != null) {
+            boolean containsValue = USERS.containsValue(remove);
+            LOGGER.info("\n 移出结果 - {}", containsValue ? "失败" : "成功");
+            return containsValue;
+        } else {
+            return true;
+        }
+    }
+
+    /**
+     * 获取在线用户列表
+     *
+     * @return 返回用户集合
+     */
+    public static Map<String, Session> getUsers() {
+        return USERS;
+    }
+
+    /**
+     * 群发消息文本消息
+     *
+     * @param message 消息内容
+     */
+    public static void sendMessageToUsersByText(String message) {
+        Collection<Session> values = USERS.values();
+        for (Session value : values) {
+            sendMessageToUserByText(value, message);
+        }
+    }
+
+    /**
+     * 发送文本消息
+     *
+     * @param session 自己的用户名
+     * @param message 消息内容
+     */
+    public static void sendMessageToUserByText(Session session, String message) {
+        if (session != null) {
+            try {
+                session.getBasicRemote().sendText(message);
+            } catch (IOException e) {
+                LOGGER.error("\n[发送消息异常]", e);
+            }
+        } else {
+            LOGGER.info("\n[你已离线]");
+        }
+    }
+}

+ 5 - 2
ruoyi-generator/src/main/java/com/ruoyi/generator/controller/CommunityChatMsgController.java

@@ -33,12 +33,15 @@ public class CommunityChatMsgController extends BaseController {
      */
     @ApiOperation("获取当前登录用户聊天记录")
     @GetMapping()
-    public AjaxResult chatRecord() {
+    public AjaxResult chatRecord(Long otherUserId) {
         Long userId = SecurityUtils.getUserId();
         List<CommunityChatMsg> chatMsgList = communityChatMsgService.list(new QueryWrapper<CommunityChatMsg>()
                 .eq("sender_id", userId)
+                .eq("receiver_id", otherUserId)
                 .or()
-                .eq("receiver_id", userId).orderByDesc("create_time"));
+                .eq("sender_id", otherUserId)
+                .eq("receiver_id", userId)
+                .orderByDesc("create_time"));
         return AjaxResult.success(chatMsgList);
     }
 }

+ 1 - 1
ruoyi-generator/src/main/java/com/ruoyi/generator/controller/CommunityCommentController.java

@@ -182,7 +182,7 @@ public class CommunityCommentController extends BaseController {
         communityCommentReply.setCreateBy(userId);
         communityCommentReply.setCreateTime(DateUtils.parseDate(DateUtils.getTime()));
         communityCommentReplyService.save(communityCommentReply);
-        return AjaxResult.success("新增回复成功!");
+        return AjaxResult.success(communityCommentReply);
     }
 
     /**

+ 4 - 0
ruoyi-system/src/main/java/com/ruoyi/system/domain/CommunityChatMsg.java

@@ -63,6 +63,10 @@ public class CommunityChatMsg implements Serializable {
      * 消息类型
      */
     private int messageType;
+    /**
+     * 是否已读
+     */
+    private boolean isRead;
     /**
     * 创建时间
     */