package utils

import (
	"context"
	"errors"
	"fmt"
	"time"

	"mtp20access/global"
	"mtp20access/model/common/request"

	"github.com/golang-jwt/jwt/v4"
)

// JWT signs and parses HS256 tokens with a shared secret.
type JWT struct {
	SigningKey []byte
}

// Sentinel errors returned by (*JWT).ParseToken; compare them with errors.Is.
var (
	ErrTokenExpired     = errors.New("token is expired")
	ErrTokenNotValidYet = errors.New("token not active yet")
	ErrTokenMalformed   = errors.New("that's not even a token")
	ErrTokenInvalid     = errors.New("couldn't handle this token ")
)

// NewJWT returns a JWT whose signing key is global.M2A_CONFIG.JWT.SigningKey.
func NewJWT() *JWT {
	return &JWT{SigningKey: []byte(global.M2A_CONFIG.JWT.SigningKey)}
}

// CreateClaims wraps baseClaims in request.CustomClaims, stamping IssuedAt and
// NotBefore with the current time, ExpiresAt with now + JWT.ExpiresTime, and
// Issuer with the configured issuer.
//
// BufferTime is copied from JWT.BufferTime. Per the original author's note, a
// token inside the buffer window (1 day) is refreshed with a new one, so a user
// can briefly hold two valid tokens; the frontend keeps only the newest and the
// other is dropped.
func (j *JWT) CreateClaims(baseClaims request.BaseClaims) request.CustomClaims {
	// One timestamp keeps IssuedAt/NotBefore identical and ExpiresAt exactly
	// ExpiresTime after them.
	now := time.Now()
	return request.CustomClaims{
		BaseClaims: baseClaims,
		BufferTime: global.M2A_CONFIG.JWT.BufferTime,
		RegisteredClaims: jwt.RegisteredClaims{
			IssuedAt:  jwt.NewNumericDate(now),
			NotBefore: jwt.NewNumericDate(now),
			// NOTE(review): ExpiresTime is converted directly to a time.Duration,
			// i.e. interpreted as nanoseconds. The original comment describes a
			// 7-day expiry from the config file, and the (formerly commented-out)
			// Redis TTL computation multiplied the same value by time.Second —
			// confirm the config unit; if it is seconds this needs `* time.Second`.
			ExpiresAt: jwt.NewNumericDate(now.Add(time.Duration(global.M2A_CONFIG.JWT.ExpiresTime))),
			Issuer:    global.M2A_CONFIG.JWT.Issuer,
		},
	}
}

// CreateToken signs claims with HS256 using j.SigningKey and returns the
// compact serialized token.
func (j *JWT) CreateToken(claims request.CustomClaims) (string, error) {
	token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
	return token.SignedString(j.SigningKey)
}

// CreateTokenByOldToken exchanges oldToken for a new token signed from claims.
//
// The work is routed through global.M2A_Concurrency_Control.Do keyed by
// "JWT:"+oldToken, so concurrent refreshes of the same old token share one
// result (presumably a singleflight.Group — TODO confirm).
func (j *JWT) CreateTokenByOldToken(oldToken string, claims request.CustomClaims) (string, error) {
	v, err, _ := global.M2A_Concurrency_Control.Do("JWT:"+oldToken, func() (interface{}, error) {
		return j.CreateToken(claims)
	})
	if err != nil {
		return "", err
	}
	// Comma-ok so an unexpected result type surfaces as an error, not a panic.
	token, ok := v.(string)
	if !ok {
		return "", fmt.Errorf("unexpected token type %T from concurrency control", v)
	}
	return token, nil
}

// ParseToken verifies tokenString against j.SigningKey and returns its claims.
//
// Validation failures are mapped to this package's sentinels, checked in this
// order: ErrTokenMalformed, ErrTokenExpired, ErrTokenNotValidYet. Any other
// failure (bad signature, non-HMAC signing method, unexpected claims type)
// yields ErrTokenInvalid.
func (j *JWT) ParseToken(tokenString string) (*request.CustomClaims, error) {
	token, err := jwt.ParseWithClaims(tokenString, &request.CustomClaims{}, func(t *jwt.Token) (interface{}, error) {
		// Only hand the shared secret to HMAC-signed tokens; a header naming
		// any other algorithm is rejected before verification.
		if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
			return nil, fmt.Errorf("unexpected signing method %v", t.Header["alg"])
		}
		return j.SigningKey, nil
	})
	if err != nil {
		var ve *jwt.ValidationError
		if !errors.As(err, &ve) {
			return nil, ErrTokenInvalid
		}
		switch {
		case ve.Errors&jwt.ValidationErrorMalformed != 0:
			return nil, ErrTokenMalformed
		case ve.Errors&jwt.ValidationErrorExpired != 0:
			return nil, ErrTokenExpired
		case ve.Errors&jwt.ValidationErrorNotValidYet != 0:
			return nil, ErrTokenNotValidYet
		default:
			return nil, ErrTokenInvalid
		}
	}
	if token == nil {
		return nil, ErrTokenInvalid
	}
	claims, ok := token.Claims.(*request.CustomClaims)
	if !ok || !token.Valid {
		return nil, ErrTokenInvalid
	}
	return claims, nil
}

// redisLoginKey builds the Redis hash key "m2a:login:<loginID>:<group>" shared
// by GetRedisLogin and SetRedisLogin.
func redisLoginKey(loginID, group int) string {
	return fmt.Sprintf("m2a:login:%d:%d", loginID, group)
}

// GetRedisLogin reads all fields of the Redis hash
// "m2a:login:<loginID>:<group>" (HGETALL) holding the user's login info.
func (j *JWT) GetRedisLogin(loginID int, group int) (values map[string]string, err error) {
	return global.M2A_REDIS.HGetAll(context.Background(), redisLoginKey(loginID, group)).Result()
}

// SetRedisLogin writes values into the Redis hash
// "m2a:login:<loginID>:<group>" (HMSET).
//
// NOTE(review): no expiry is set on the key, although the original comment
// said it should match the JWT expiry (the TTL code was commented out) —
// confirm whether an EXPIRE is intended.
func (j *JWT) SetRedisLogin(loginID int, group int, values map[string]interface{}) (err error) {
	return global.M2A_REDIS.HMSet(context.Background(), redisLoginKey(loginID, group), values).Err()
}

// SetOriRedisToken sets the "Token" field of the original login service's
// Redis hash "monitor:online_loginid:<loginID>:<group>" to a token of the form
// "<loginID>_<unixSeconds>_<group>".
func (j *JWT) SetOriRedisToken(loginID int, group int) (err error) {
	key := fmt.Sprintf("monitor:online_loginid:%d:%d", loginID, group)
	token := fmt.Sprintf("%d_%d_%d", loginID, time.Now().Unix(), group)
	return global.M2A_REDIS.HSet(context.Background(), key, "Token", token).Err()
}