package db

import (
	"encoding/hex"
	"fmt"

	"mtp2_if/config"
	"mtp2_if/packet"

	"gopkg.in/mgo.v2"
)

var (
	// session is the session dialed by the last successful InitMongoDB call.
	session *mgo.Session
	// mongodb is the database handle obtained from session.
	mongodb *mgo.Database
)

const (
	// credPrefixLen and credSuffixLen are the number of bytes decryptCredential
	// strips from the front and the back of a hex-decoded credential before
	// passing the remainder to packet.Decrypt.
	credPrefixLen = 4
	credSuffixLen = 8
)

// InitMongoDB connects to MongoDB using the hostname, port, database name and
// credentials from config.SerCfg, logs in, and stores the session and database
// handle for GetMongoDB and CloseMongoDB.
//
// The username and password may be configured either in plain text or in the
// encrypted hex form understood by decryptCredential.
//
// The package-level session and database are only replaced after a successful
// login. If login fails, the newly dialed session is closed before the error
// is returned. Errors from mgo.Dial and Database.Login are returned unchanged.
func InitMongoDB() error {
	addr := fmt.Sprintf("%s:%d", config.SerCfg.GetMongoDBHostname(), config.SerCfg.GetMongoDBPort())
	s, err := mgo.Dial(addr)
	if err != nil {
		return err
	}

	database := s.DB(config.SerCfg.GetMongoDBDBName())

	user := decryptCredential(config.SerCfg.GetMongoDBUsername())
	pwd := decryptCredential(config.SerCfg.GetMongoDBPassword())
	if err := database.Login(user, pwd); err != nil {
		// Release the connection instead of leaking it and leaving the
		// globals pointing at an unauthenticated session.
		s.Close()
		return err
	}

	session = s
	mongodb = database
	return nil
}

// decryptCredential returns the plain-text form of a configured credential.
//
// When value is valid hex that decodes to more than credPrefixLen+credSuffixLen
// bytes, the prefix and suffix are stripped. The remaining bytes are then
// decrypted with packet.Decrypt using packet.AESKey.
//
// In every other case value is returned unchanged, so plain-text credentials
// keep working. That covers invalid hex, a too-short value, a decryption error
// and a nil decryption result.
func decryptCredential(value string) string {
	raw, err := hex.DecodeString(value)
	// Requiring strictly more than prefix+suffix bytes keeps the slice below
	// in range and guarantees a non-empty payload.
	if err != nil || len(raw) <= credPrefixLen+credSuffixLen {
		return value
	}

	plain, decErr := packet.Decrypt(raw[credPrefixLen:len(raw)-credSuffixLen], packet.AESKey, true)
	if decErr != nil || plain == nil {
		return value
	}
	return string(plain)
}

// GetMongoDB returns the database handle stored by the last successful
// InitMongoDB call, or nil if InitMongoDB has not succeeded.
func GetMongoDB() *mgo.Database {
	return mongodb
}

// CloseMongoDB closes the session opened by InitMongoDB. It does nothing if
// no session has been established.
func CloseMongoDB() {
	if session != nil {
		session.Close()
	}
}