diff --git a/bot.go b/bot.go index d2d9256..a74894f 100644 --- a/bot.go +++ b/bot.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "log" + "net/url" ) type Bot struct { @@ -18,7 +19,6 @@ type Bot struct { self *Self storage *Storage hotReloadStorage HotReloadStorage - mode mode } // 判断当前用户是否正常在线 @@ -49,8 +49,13 @@ func (b *Bot) HotLogin(storage HotReloadStorage) error { return b.Login() } cookies := storage.GetCookie() - path := b.Caller.Client.getBaseUrl() - b.Caller.Client.Jar.SetCookies(path, cookies) + for u, ck := range cookies { + path, err := url.Parse(u) + if err != nil { + return err + } + b.Caller.Client.Jar.SetCookies(path, ck) + } b.storage.LoginInfo = storage.GetLoginInfo() b.storage.Request = storage.GetBaseRequest() return b.webInit() @@ -123,8 +128,9 @@ func (b *Bot) login(data []byte) error { // 将BaseRequest存到storage里面方便后续调用 b.storage.Request = request + // 如果是热登陆,则将当前的重要信息写入hotReloadStorage if b.isHot { - cookies := b.Caller.Client.getCookies() + cookies := b.Caller.Client.GetCookieMap() if err := b.hotReloadStorage.Dump(cookies, request, info); err != nil { return err } diff --git a/bot_test.go b/bot_test.go index e9bdd3d..43e2f4a 100644 --- a/bot_test.go +++ b/bot_test.go @@ -222,8 +222,9 @@ func TestAgreeFriendsAdd(t *testing.T) { } func TestHotLogin(t *testing.T) { + filename := "test.json" bot := defaultBot() - s := NewFileHotReloadStorage("2.json") + s := NewJsonFileHotReloadStorage(filename) if err := bot.HotLogin(s); err != nil { t.Error(err) return diff --git a/client.go b/client.go index eeaf24a..93c3491 100644 --- a/client.go +++ b/client.go @@ -12,6 +12,7 @@ import ( "os" "strconv" "strings" + "sync" "time" ) @@ -21,6 +22,8 @@ import ( type Client struct { *http.Client UrlManager + mu sync.Mutex + cookies map[string][]*http.Cookie } func NewClient(client *http.Client, urlManager UrlManager) *Client { @@ -40,14 +43,26 @@ func DefaultClient(urlManager UrlManager) *Client { return NewClient(client, urlManager) } -func (c *Client) getBaseUrl() *url.URL { - path, _ := url.Parse(c.UrlManager.baseUrl) - return path +// 抽象Do方法,将所有的有效的cookie存入Client.cookies +// 方便热登陆时获取 +func (c *Client) Do(req *http.Request) (*http.Response, error) { + resp, err := c.Client.Do(req) + if err != nil { + return resp, err + } + c.mu.Lock() + defer c.mu.Unlock() + cookies := resp.Cookies() + if c.cookies == nil { + c.cookies = make(map[string][]*http.Cookie) + } + c.cookies[resp.Request.URL.String()] = cookies + return resp, err } -func (c *Client) getCookies() []*http.Cookie { - path := c.getBaseUrl() - return c.Jar.Cookies(path) +// 获取当前client的所有的有效的client +func (c *Client) GetCookieMap() map[string][]*http.Cookie { + return c.cookies } // 获取登录的uuid @@ -60,7 +75,8 @@ func (c *Client) GetLoginUUID() (*http.Response, error) { params.Add("lang", "zh_CN") params.Add("_", strconv.FormatInt(time.Now().Unix(), 10)) path.RawQuery = params.Encode() - return c.Get(path.String()) + req, _ := http.NewRequest(http.MethodGet, path.String(), nil) + return c.Do(req) } // 获取登录的二维吗 @@ -80,7 +96,8 @@ func (c *Client) CheckLogin(uuid string) (*http.Response, error) { params.Add("uuid", uuid) params.Add("tip", "0") path.RawQuery = params.Encode() - return c.Get(path.String()) + req, _ := http.NewRequest(http.MethodGet, path.String(), nil) + return c.Do(req) } // GetLoginInfo 请求获取LoginInfo @@ -102,7 +119,9 @@ func (c *Client) WebInit(request *BaseRequest) (*http.Response, error) { if err != nil { return nil, err } - return c.Post(path.String(), jsonContentType, body) + req, _ := http.NewRequest(http.MethodPost, path.String(), body) + req.Header.Add("Content-Type", jsonContentType) + return c.Do(req) } // 通知手机已登录 @@ -157,7 +176,8 @@ func (c *Client) WebWxGetContact(info *LoginInfo) (*http.Response, error) { params.Add("skey", info.SKey) params.Add("req", "0") path.RawQuery = params.Encode() - return c.Get(path.String()) + req, _ := http.NewRequest(http.MethodGet, path.String(), nil) + return c.Do(req) } // 获取联系人详情 @@ -226,7 +246,8 @@ func (c *Client) WebWxSendMsg(msg *SendMessage, info *LoginInfo, request *BaseRe // 获取用户的头像 func (c *Client) WebWxGetHeadImg(headImageUrl string) (*http.Response, error) { path := c.baseUrl + headImageUrl - return c.Get(path) + req, _ := http.NewRequest(http.MethodGet, path, nil) + return c.Do(req) } // 上传文件 @@ -291,7 +312,9 @@ func (c *Client) WebWxUploadMedia(file *os.File, request *BaseRequest, info *Log if err = writer.Close(); err != nil { return nil, err } - return c.Post(path.String(), ct, body) + req, _ := http.NewRequest(http.MethodPost, path.String(), body) + req.Header.Set("Content-Type", ct) + return c.Do(req) } // 发送图片 @@ -374,7 +397,8 @@ func (c *Client) WebWxGetMsgImg(msg *Message, info *LoginInfo) (*http.Response, params.Add("skey", info.SKey) params.Add("type", "slave") path.RawQuery = params.Encode() - return c.Get(path.String()) + req, _ := http.NewRequest(http.MethodGet, path.String(), nil) + return c.Do(req) } // 获取语音消息的语音响应 @@ -384,7 +408,8 @@ func (c *Client) WebWxGetVoice(msg *Message, info *LoginInfo) (*http.Response, e params.Add("msgid", msg.MsgId) params.Add("skey", info.SKey) path.RawQuery = params.Encode() - return c.Get(path.String()) + req, _ := http.NewRequest(http.MethodGet, path.String(), nil) + return c.Do(req) } // 获取视频消息的视频响应 @@ -394,7 +419,8 @@ func (c *Client) WebWxGetVideo(msg *Message, info *LoginInfo) (*http.Response, e params.Add("msgid", msg.MsgId) params.Add("skey", info.SKey) path.RawQuery = params.Encode() - return c.Get(path.String()) + req, _ := http.NewRequest(http.MethodGet, path.String(), nil) + return c.Do(req) } // 获取文件消息的文件响应 @@ -408,7 +434,8 @@ func (c *Client) WebWxGetMedia(msg *Message, info *LoginInfo) (*http.Response, e params.Add("pass_ticket", info.PassTicket) params.Add("webwx_data_ticket", getWebWxDataTicket(c.Jar.Cookies(path))) path.RawQuery = params.Encode() - return c.Get(path.String()) + req, _ := http.NewRequest(http.MethodGet, path.String(), nil) + return c.Do(req) } // 用户退出 @@ -419,7 +446,8 @@ func (c *Client) Logout(info *LoginInfo) (*http.Response, error) { params.Add("type", "1") params.Add("skey", info.SKey) path.RawQuery = params.Encode() - return c.Get(path.String()) + req, _ := http.NewRequest(http.MethodGet, path.String(), nil) + return c.Do(req) } // 添加用户进群聊 @@ -440,7 +468,9 @@ func (c *Client) AddMemberIntoChatRoom(req *BaseRequest, info *LoginInfo, group "AddMemberList": strings.Join(addMemberList, ","), } buffer, _ := ToBuffer(content) - return c.Post(path.String(), jsonContentType, buffer) + requ, _ := http.NewRequest(http.MethodPost, path.String(), buffer) + requ.Header.Set("Content-Type", jsonContentType) + return c.Do(requ) } // 从群聊中移除用户 @@ -460,5 +490,7 @@ func (c *Client) RemoveMemberFromChatRoom(req *BaseRequest, info *LoginInfo, gro "DelMemberList": strings.Join(delMemberList, ","), } buffer, _ := ToBuffer(content) - return c.Post(path.String(), jsonContentType, buffer) + requ, _ := http.NewRequest(http.MethodPost, path.String(), buffer) + requ.Header.Set("Content-Type", jsonContentType) + return c.Do(requ) } diff --git a/stroage.go b/stroage.go index ed2d50c..dd34b4e 100644 --- a/stroage.go +++ b/stroage.go @@ -7,31 +7,33 @@ import ( "os" ) +// 身份信息, 维持整个登陆的Session会话 type Storage struct { LoginInfo *LoginInfo Request *BaseRequest Response *WebInitResponse } +// 热登陆存储接口 type HotReloadStorage interface { - GetCookie() []*http.Cookie - GetBaseRequest() *BaseRequest - GetLoginInfo() *LoginInfo - Dump(cookies []*http.Cookie, req *BaseRequest, info *LoginInfo) error - Load() error + GetCookie() map[string][]*http.Cookie // 获取client.cookie + GetBaseRequest() *BaseRequest // 获取BaseRequest + GetLoginInfo() *LoginInfo // 获取LoginInfo + Dump(cookies map[string][]*http.Cookie, req *BaseRequest, info *LoginInfo) error // 实现该方法, 将必要信息进行序列化 + Load() error // 实现该方法, 将存储媒介的内容反序列化 } -type FileHotReloadStorage struct { - Cookie []*http.Cookie +// 实现HotReloadStorage接口 +// 默认以json文件的形式存储 +type JsonFileHotReloadStorage struct { + Cookie map[string][]*http.Cookie Req *BaseRequest Info *LoginInfo filename string } -func (f *FileHotReloadStorage) Dump(cookies []*http.Cookie, req *BaseRequest, info *LoginInfo) error { - f.Cookie = cookies - f.Req = req - f.Info = info +// 将信息写入json文件 +func (f *JsonFileHotReloadStorage) Dump(cookies map[string][]*http.Cookie, req *BaseRequest, info *LoginInfo) error { var ( file *os.File err error @@ -55,6 +57,10 @@ func (f *FileHotReloadStorage) Dump(cookies []*http.Cookie, req *BaseRequest, in } defer file.Close() + f.Cookie = cookies + f.Req = req + f.Info = info + data, err := json.Marshal(f) if err != nil { return err @@ -63,7 +69,8 @@ func (f *FileHotReloadStorage) Dump(cookies []*http.Cookie, req *BaseRequest, in return err } -func (f *FileHotReloadStorage) Load() error { +// 从文件中读取信息 +func (f *JsonFileHotReloadStorage) Load() error { file, err := os.Open(f.filename) if err != nil { return err @@ -76,18 +83,18 @@ func (f *FileHotReloadStorage) Load() error { return json.Unmarshal(buffer.Bytes(), f) } -func (f FileHotReloadStorage) GetCookie() []*http.Cookie { +func (f *JsonFileHotReloadStorage) GetCookie() map[string][]*http.Cookie { return f.Cookie } -func (f FileHotReloadStorage) GetBaseRequest() *BaseRequest { +func (f *JsonFileHotReloadStorage) GetBaseRequest() *BaseRequest { return f.Req } -func (f FileHotReloadStorage) GetLoginInfo() *LoginInfo { +func (f *JsonFileHotReloadStorage) GetLoginInfo() *LoginInfo { return f.Info } -func NewFileHotReloadStorage(filename string) *FileHotReloadStorage { - return &FileHotReloadStorage{filename: filename} +func NewJsonFileHotReloadStorage(filename string) *JsonFileHotReloadStorage { + return &JsonFileHotReloadStorage{filename: filename} }