login.go 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265
  1. package account
  2. import (
  3. "encoding/base64"
  4. "encoding/hex"
  5. "errors"
  6. "fmt"
  7. "mtp20access/client"
  8. "mtp20access/global"
  9. accountModel "mtp20access/model/account"
  10. "mtp20access/model/account/request"
  11. jwtRequest "mtp20access/model/common/request"
  12. "mtp20access/packet"
  13. "sync"
  14. "mtp20access/utils"
  15. "strconv"
  16. "go.uber.org/zap"
  17. )
  18. var (
  19. mtx sync.RWMutex
  20. curSessionID int = 90000 // 本服务SessionID从90000开始,以避免与旧登录服务重叠
  21. )
  22. // Login 用户登录
  23. func Login(req request.LoginReq, addr string) (loginaccount *accountModel.Loginaccount, token string, expiresAt int64, err error) {
  24. // 分别尝试用LoginID、LoginCode和手机号码进行登录
  25. loginaccount, err = getLoginAccount(req.UserName, req.Password)
  26. if err != nil {
  27. return
  28. }
  29. // 判断用户状态
  30. if loginaccount.LOGINSTATUS == 2 {
  31. err = errors.New("账户已冻结")
  32. return
  33. }
  34. if loginaccount.LOGINSTATUS == 3 {
  35. err = errors.New("账户已注销")
  36. return
  37. }
  38. // 生成Token,并写入Redis
  39. if token, expiresAt, err = buildRedisLoginInfo(*loginaccount, addr, req.ClientType); err != nil {
  40. return
  41. }
  42. return
  43. }
  44. // getLoginAccount 分别尝试用LoginID、LoginCode和手机号码进行登录
  45. func getLoginAccount(userName string, password string) (loginaccount *accountModel.Loginaccount, err error) {
  46. // 密码解密(5.0报文解密)
  47. d, err := base64.StdEncoding.DecodeString(password)
  48. if err != nil {
  49. return
  50. }
  51. d1 := d[4 : len(d)-8] // 解密时要去头尾
  52. p, err := packet.Decrypt(d1, packet.AESKey, true)
  53. if err != nil {
  54. return
  55. }
  56. pwd := string(p)
  57. // 通过LoginID查询
  58. if loginID, _ := strconv.Atoi(userName); loginID != 0 {
  59. loginaccount = &accountModel.Loginaccount{
  60. LOGINID: int64(loginID),
  61. PASSWORD: utils.EncoderSha256(fmt.Sprintf("%s%s", userName, pwd)), // 构建数据库存储的密码
  62. }
  63. if has, _ := global.M2A_DB.Get(loginaccount); has {
  64. return
  65. }
  66. }
  67. // 通过LoginCode查询
  68. loginaccount = &accountModel.Loginaccount{
  69. LOGINCODE: userName,
  70. }
  71. if has, _ := global.M2A_DB.Get(loginaccount); has {
  72. // 构建数据库存储的密码
  73. if loginaccount.PASSWORD == utils.EncoderSha256(fmt.Sprintf("%d%s", loginaccount.LOGINID, pwd)) {
  74. return
  75. }
  76. }
  77. // 通过手机号码查询,需要AES加密
  78. key, _ := hex.DecodeString(utils.AESSecretKey)
  79. if mobileEncrypted, _ := utils.AESEncrypt([]byte(userName), key); mobileEncrypted != nil {
  80. // 从三方表获取LoginID
  81. userauthinfo := &accountModel.Userauthinfo{
  82. AUTHID: string(mobileEncrypted),
  83. AUTHTYPE: 3,
  84. }
  85. if has, _ := global.M2A_DB.Get(&userauthinfo); has {
  86. loginaccount = &accountModel.Loginaccount{
  87. LOGINID: userauthinfo.LOGINID,
  88. PASSWORD: utils.EncoderSha256(fmt.Sprintf("%s%s", userName, pwd)), // 构建数据库存储的密码
  89. }
  90. if has, _ := global.M2A_DB.Get(loginaccount); has {
  91. return
  92. }
  93. }
  94. // loginaccount = &accountModel.Loginaccount{
  95. // MOBILE: string(mobileEncrypted),
  96. // }
  97. // if has, _ := global.M2A_DB.Get(loginaccount); has {
  98. // // 构建数据库存储的密码
  99. // if loginaccount.PASSWORD == utils.EncoderSha256(fmt.Sprintf("%d%s", loginaccount.LOGINID, pwd)) {
  100. // return
  101. // }
  102. // }
  103. }
  104. err = errors.New("错误的用户名或密码")
  105. return
  106. }
  107. // newSessionID 获取
  108. func newSessionID() int {
  109. mtx.RLock()
  110. defer mtx.RUnlock()
  111. curSessionID += 1
  112. return curSessionID
  113. }
  114. // buildRedisLoginInfo 生成Token,并写入Redis
  115. func buildRedisLoginInfo(loginaccount accountModel.Loginaccount, addr string, group int) (token string, expiresAt int64, err error) {
  116. // 生成SessionID
  117. sessionID := newSessionID()
  118. // 生成本服务Token
  119. j := &utils.JWT{SigningKey: []byte(global.M2A_CONFIG.JWT.SigningKey)} // 唯一签名
  120. claims := j.CreateClaims(jwtRequest.BaseClaims{
  121. LoginID: int(loginaccount.LOGINID),
  122. Group: group,
  123. SessionID: sessionID,
  124. })
  125. token, err = j.CreateToken(claims)
  126. if err != nil {
  127. global.M2A_LOG.Error("生成本服token失败", zap.Error(err))
  128. return
  129. }
  130. expiresAt = claims.RegisteredClaims.ExpiresAt.Unix()
  131. loginLogin := client.LoginRedis{
  132. LoginID: strconv.Itoa(int(loginaccount.LOGINID)),
  133. UserID: strconv.Itoa(int(loginaccount.USERID)),
  134. SessionID: strconv.Itoa(sessionID),
  135. Token: token,
  136. Group: strconv.Itoa(group),
  137. Addr: addr,
  138. }
  139. loginMap, err := loginLogin.ToMap()
  140. // loginMap := map[string]interface{}{
  141. // "LoginID": strconv.Itoa(int(loginaccount.LOGINID)),
  142. // "UserID": strconv.Itoa(int(loginaccount.USERID)),
  143. // "SessionID": strconv.Itoa(sessionID),
  144. // "Token": token,
  145. // "Group": strconv.Itoa(group),
  146. // "Addr": addr,
  147. // }
  148. if err != nil {
  149. global.M2A_LOG.Error("生成登录信息MAP失败", zap.Error(err))
  150. return
  151. }
  152. if err = j.SetRedisLogin(int(loginaccount.LOGINID), group, loginMap); err != nil {
  153. global.M2A_LOG.Error("Token写入Redis失败", zap.Error(err))
  154. return
  155. }
  156. // 生成旧登录服务Token
  157. // if err = j.SetOriRedisToken(int(loginaccount.LOGINID), group); err != nil {
  158. // // FIXME: 这里有类事务的回滚问题
  159. // global.M2A_LOG.Error("生成旧登录服务Token失败", zap.Error(err))
  160. // return
  161. // }
  162. // 记录用户信息
  163. mtx.Lock()
  164. defer mtx.Unlock()
  165. if client.Clients == nil {
  166. client.Clients = make(map[int]*client.Client, 0)
  167. }
  168. delete(client.Clients, claims.SessionID)
  169. client.Clients[claims.SessionID] = &client.Client{LoginRedis: loginLogin}
  170. return
  171. }
  172. // RestoreLoginWithToken 通过Token检验恢复登录状态失败
  173. func RestoreLoginWithToken(loginID int, group int, token string) (err error) {
  174. // 从Redis获取登录信息
  175. j := utils.NewJWT()
  176. loginMap, err := j.GetRedisLogin(loginID, group)
  177. if err != nil {
  178. global.M2A_LOG.Error("Token检验恢复登录状态失败", zap.Error(err))
  179. return
  180. }
  181. loginId := loginMap["loginId"]
  182. userId := loginMap["userId"]
  183. sessionId := loginMap["sessionId"]
  184. addr := loginMap["addr"]
  185. loginLogin := client.LoginRedis{
  186. LoginID: loginId,
  187. UserID: userId,
  188. SessionID: sessionId,
  189. Token: token,
  190. Group: strconv.Itoa(group),
  191. Addr: addr,
  192. }
  193. // 记录用户信息
  194. mtx.Lock()
  195. defer mtx.Unlock()
  196. if client.Clients == nil {
  197. client.Clients = make(map[int]*client.Client, 0)
  198. }
  199. s, err := strconv.Atoi(sessionId)
  200. if err != nil {
  201. global.M2A_LOG.Error("Token检验恢复登录状态失败", zap.Error(err))
  202. return
  203. }
  204. delete(client.Clients, s)
  205. client.Clients[s] = &client.Client{LoginRedis: loginLogin}
  206. return
  207. }
  208. // GetClientsByAccountID 通过资金账户获取所有的
  209. func GetClientsByAccountID(accountID uint64) (clients []*client.Client, err error) {
  210. clients = make([]*client.Client, 0)
  211. loginIds := make([]string, 0)
  212. sql := fmt.Sprintf(`
  213. SELECT
  214. to_char(t.loginid)
  215. FROM loginaccount t
  216. INNER JOIN taaccount a ON a.userid = t.userid
  217. WHERE a.accountid = %v
  218. `, accountID)
  219. if err = global.M2A_DB.SQL(sql).Find(&loginIds); err != nil {
  220. global.M2A_LOG.Error("获取LoginID失败", zap.Error(err))
  221. return
  222. }
  223. var mtx sync.RWMutex
  224. mtx.Lock()
  225. defer mtx.Unlock()
  226. if len(loginIds) > 0 && len(client.Clients) > 0 {
  227. for _, item := range loginIds {
  228. for i := range client.Clients {
  229. c := client.Clients[i]
  230. if c.LoginID == item {
  231. clients = append(clients, c)
  232. }
  233. }
  234. }
  235. }
  236. return
  237. }