mongodb.go 1.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667
  1. package db
  2. import (
  3. "encoding/hex"
  4. "fmt"
  5. "mtp2_if/config"
  6. "mtp2_if/packet"
  7. "gopkg.in/mgo.v2"
  8. )
  9. var session *mgo.Session
  10. var mongodb *mgo.Database
  11. // InitMongoDB 初始化连接MongoDB
  12. func InitMongoDB() error {
  13. // 创建链接
  14. var err error
  15. session, err = mgo.Dial(fmt.Sprintf("%s:%d", config.SerCfg.GetMongoDBHostname(), config.SerCfg.GetMongoDBPort()))
  16. if err != nil {
  17. return err
  18. }
  19. // 选择DB
  20. mongodb = session.DB(config.SerCfg.GetMongoDBDBName())
  21. // 尝试解密
  22. var dbUser []byte
  23. ciphertext, _ := hex.DecodeString(config.SerCfg.GetMongoDBUsername())
  24. if len(ciphertext) > 8 {
  25. ciphertext = ciphertext[4 : len(ciphertext)-8]
  26. dbUser, _ = packet.Decrypt(ciphertext, packet.AESKey, true)
  27. if dbUser == nil {
  28. dbUser = []byte(config.SerCfg.GetMongoDBUsername())
  29. }
  30. } else {
  31. dbUser = []byte(config.SerCfg.GetMongoDBUsername())
  32. }
  33. var dbPwd []byte
  34. ciphertext, _ = hex.DecodeString(config.SerCfg.GetMongoDBPassword())
  35. if len(ciphertext) > 8 {
  36. ciphertext = ciphertext[4 : len(ciphertext)-8]
  37. dbPwd, _ = packet.Decrypt(ciphertext, packet.AESKey, true)
  38. if dbPwd == nil {
  39. dbPwd = []byte(config.SerCfg.GetMongoDBPassword())
  40. }
  41. } else {
  42. dbPwd = []byte(config.SerCfg.GetMongoDBPassword())
  43. }
  44. // 登陆
  45. if err := mongodb.Login(string(dbUser), string(dbPwd)); err != nil {
  46. return err
  47. }
  48. return nil
  49. }
  50. // GetMongoDB 获取MongoDB Database
  51. func GetMongoDB() *mgo.Database {
  52. return mongodb
  53. }
  54. // CloseMongoDB 关闭MongoDB连接
  55. func CloseMongoDB() {
  56. session.Close()
  57. }