修复热登陆时cookie无效的问题 🐛

This commit is contained in:
eatMoreApple 2021-04-28 10:38:52 +08:00
parent a8c646b33d
commit cd802c3a8d
4 changed files with 87 additions and 41 deletions

14
bot.go
View File

@ -4,6 +4,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"log" "log"
"net/url"
) )
type Bot struct { type Bot struct {
@ -18,7 +19,6 @@ type Bot struct {
self *Self self *Self
storage *Storage storage *Storage
hotReloadStorage HotReloadStorage hotReloadStorage HotReloadStorage
mode mode
} }
// 判断当前用户是否正常在线 // 判断当前用户是否正常在线
@ -49,8 +49,13 @@ func (b *Bot) HotLogin(storage HotReloadStorage) error {
return b.Login() return b.Login()
} }
cookies := storage.GetCookie() cookies := storage.GetCookie()
path := b.Caller.Client.getBaseUrl() for u, ck := range cookies {
b.Caller.Client.Jar.SetCookies(path, cookies) path, err := url.Parse(u)
if err != nil {
return err
}
b.Caller.Client.Jar.SetCookies(path, ck)
}
b.storage.LoginInfo = storage.GetLoginInfo() b.storage.LoginInfo = storage.GetLoginInfo()
b.storage.Request = storage.GetBaseRequest() b.storage.Request = storage.GetBaseRequest()
return b.webInit() return b.webInit()
@ -123,8 +128,9 @@ func (b *Bot) login(data []byte) error {
// 将BaseRequest存到storage里面方便后续调用 // 将BaseRequest存到storage里面方便后续调用
b.storage.Request = request b.storage.Request = request
// 如果是热登陆,则将当前的重要信息写入hotReloadStorage
if b.isHot { if b.isHot {
cookies := b.Caller.Client.getCookies() cookies := b.Caller.Client.GetCookieMap()
if err := b.hotReloadStorage.Dump(cookies, request, info); err != nil { if err := b.hotReloadStorage.Dump(cookies, request, info); err != nil {
return err return err
} }

View File

@ -222,8 +222,9 @@ func TestAgreeFriendsAdd(t *testing.T) {
} }
func TestHotLogin(t *testing.T) { func TestHotLogin(t *testing.T) {
filename := "test.json"
bot := defaultBot() bot := defaultBot()
s := NewFileHotReloadStorage("2.json") s := NewJsonFileHotReloadStorage(filename)
if err := bot.HotLogin(s); err != nil { if err := bot.HotLogin(s); err != nil {
t.Error(err) t.Error(err)
return return

View File

@ -12,6 +12,7 @@ import (
"os" "os"
"strconv" "strconv"
"strings" "strings"
"sync"
"time" "time"
) )
@ -21,6 +22,8 @@ import (
type Client struct { type Client struct {
*http.Client *http.Client
UrlManager UrlManager
mu sync.Mutex
cookies map[string][]*http.Cookie
} }
func NewClient(client *http.Client, urlManager UrlManager) *Client { func NewClient(client *http.Client, urlManager UrlManager) *Client {
@ -40,14 +43,26 @@ func DefaultClient(urlManager UrlManager) *Client {
return NewClient(client, urlManager) return NewClient(client, urlManager)
} }
func (c *Client) getBaseUrl() *url.URL { // 抽象Do方法,将所有的有效的cookie存入Client.cookies
path, _ := url.Parse(c.UrlManager.baseUrl) // 方便热登陆时获取
return path func (c *Client) Do(req *http.Request) (*http.Response, error) {
resp, err := c.Client.Do(req)
if err != nil {
return resp, err
}
c.mu.Lock()
defer c.mu.Unlock()
cookies := resp.Cookies()
if c.cookies == nil {
c.cookies = make(map[string][]*http.Cookie)
}
c.cookies[resp.Request.URL.String()] = cookies
return resp, err
} }
func (c *Client) getCookies() []*http.Cookie { // 获取当前client的所有的有效的client
path := c.getBaseUrl() func (c *Client) GetCookieMap() map[string][]*http.Cookie {
return c.Jar.Cookies(path) return c.cookies
} }
// 获取登录的uuid // 获取登录的uuid
@ -60,7 +75,8 @@ func (c *Client) GetLoginUUID() (*http.Response, error) {
params.Add("lang", "zh_CN") params.Add("lang", "zh_CN")
params.Add("_", strconv.FormatInt(time.Now().Unix(), 10)) params.Add("_", strconv.FormatInt(time.Now().Unix(), 10))
path.RawQuery = params.Encode() path.RawQuery = params.Encode()
return c.Get(path.String()) req, _ := http.NewRequest(http.MethodGet, path.String(), nil)
return c.Do(req)
} }
// 获取登录的二维吗 // 获取登录的二维吗
@ -80,7 +96,8 @@ func (c *Client) CheckLogin(uuid string) (*http.Response, error) {
params.Add("uuid", uuid) params.Add("uuid", uuid)
params.Add("tip", "0") params.Add("tip", "0")
path.RawQuery = params.Encode() path.RawQuery = params.Encode()
return c.Get(path.String()) req, _ := http.NewRequest(http.MethodGet, path.String(), nil)
return c.Do(req)
} }
// GetLoginInfo 请求获取LoginInfo // GetLoginInfo 请求获取LoginInfo
@ -102,7 +119,9 @@ func (c *Client) WebInit(request *BaseRequest) (*http.Response, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
return c.Post(path.String(), jsonContentType, body) req, _ := http.NewRequest(http.MethodPost, path.String(), body)
req.Header.Add("Content-Type", jsonContentType)
return c.Do(req)
} }
// 通知手机已登录 // 通知手机已登录
@ -157,7 +176,8 @@ func (c *Client) WebWxGetContact(info *LoginInfo) (*http.Response, error) {
params.Add("skey", info.SKey) params.Add("skey", info.SKey)
params.Add("req", "0") params.Add("req", "0")
path.RawQuery = params.Encode() path.RawQuery = params.Encode()
return c.Get(path.String()) req, _ := http.NewRequest(http.MethodGet, path.String(), nil)
return c.Do(req)
} }
// 获取联系人详情 // 获取联系人详情
@ -226,7 +246,8 @@ func (c *Client) WebWxSendMsg(msg *SendMessage, info *LoginInfo, request *BaseRe
// 获取用户的头像 // 获取用户的头像
func (c *Client) WebWxGetHeadImg(headImageUrl string) (*http.Response, error) { func (c *Client) WebWxGetHeadImg(headImageUrl string) (*http.Response, error) {
path := c.baseUrl + headImageUrl path := c.baseUrl + headImageUrl
return c.Get(path) req, _ := http.NewRequest(http.MethodGet, path, nil)
return c.Do(req)
} }
// 上传文件 // 上传文件
@ -291,7 +312,9 @@ func (c *Client) WebWxUploadMedia(file *os.File, request *BaseRequest, info *Log
if err = writer.Close(); err != nil { if err = writer.Close(); err != nil {
return nil, err return nil, err
} }
return c.Post(path.String(), ct, body) req, _ := http.NewRequest(http.MethodPost, path.String(), body)
req.Header.Set("Content-Type", ct)
return c.Do(req)
} }
// 发送图片 // 发送图片
@ -374,7 +397,8 @@ func (c *Client) WebWxGetMsgImg(msg *Message, info *LoginInfo) (*http.Response,
params.Add("skey", info.SKey) params.Add("skey", info.SKey)
params.Add("type", "slave") params.Add("type", "slave")
path.RawQuery = params.Encode() path.RawQuery = params.Encode()
return c.Get(path.String()) req, _ := http.NewRequest(http.MethodGet, path.String(), nil)
return c.Do(req)
} }
// 获取语音消息的语音响应 // 获取语音消息的语音响应
@ -384,7 +408,8 @@ func (c *Client) WebWxGetVoice(msg *Message, info *LoginInfo) (*http.Response, e
params.Add("msgid", msg.MsgId) params.Add("msgid", msg.MsgId)
params.Add("skey", info.SKey) params.Add("skey", info.SKey)
path.RawQuery = params.Encode() path.RawQuery = params.Encode()
return c.Get(path.String()) req, _ := http.NewRequest(http.MethodGet, path.String(), nil)
return c.Do(req)
} }
// 获取视频消息的视频响应 // 获取视频消息的视频响应
@ -394,7 +419,8 @@ func (c *Client) WebWxGetVideo(msg *Message, info *LoginInfo) (*http.Response, e
params.Add("msgid", msg.MsgId) params.Add("msgid", msg.MsgId)
params.Add("skey", info.SKey) params.Add("skey", info.SKey)
path.RawQuery = params.Encode() path.RawQuery = params.Encode()
return c.Get(path.String()) req, _ := http.NewRequest(http.MethodGet, path.String(), nil)
return c.Do(req)
} }
// 获取文件消息的文件响应 // 获取文件消息的文件响应
@ -408,7 +434,8 @@ func (c *Client) WebWxGetMedia(msg *Message, info *LoginInfo) (*http.Response, e
params.Add("pass_ticket", info.PassTicket) params.Add("pass_ticket", info.PassTicket)
params.Add("webwx_data_ticket", getWebWxDataTicket(c.Jar.Cookies(path))) params.Add("webwx_data_ticket", getWebWxDataTicket(c.Jar.Cookies(path)))
path.RawQuery = params.Encode() path.RawQuery = params.Encode()
return c.Get(path.String()) req, _ := http.NewRequest(http.MethodGet, path.String(), nil)
return c.Do(req)
} }
// 用户退出 // 用户退出
@ -419,7 +446,8 @@ func (c *Client) Logout(info *LoginInfo) (*http.Response, error) {
params.Add("type", "1") params.Add("type", "1")
params.Add("skey", info.SKey) params.Add("skey", info.SKey)
path.RawQuery = params.Encode() path.RawQuery = params.Encode()
return c.Get(path.String()) req, _ := http.NewRequest(http.MethodGet, path.String(), nil)
return c.Do(req)
} }
// 添加用户进群聊 // 添加用户进群聊
@ -440,7 +468,9 @@ func (c *Client) AddMemberIntoChatRoom(req *BaseRequest, info *LoginInfo, group
"AddMemberList": strings.Join(addMemberList, ","), "AddMemberList": strings.Join(addMemberList, ","),
} }
buffer, _ := ToBuffer(content) buffer, _ := ToBuffer(content)
return c.Post(path.String(), jsonContentType, buffer) requ, _ := http.NewRequest(http.MethodPost, path.String(), buffer)
requ.Header.Set("Content-Type", jsonContentType)
return c.Do(requ)
} }
// 从群聊中移除用户 // 从群聊中移除用户
@ -460,5 +490,7 @@ func (c *Client) RemoveMemberFromChatRoom(req *BaseRequest, info *LoginInfo, gro
"DelMemberList": strings.Join(delMemberList, ","), "DelMemberList": strings.Join(delMemberList, ","),
} }
buffer, _ := ToBuffer(content) buffer, _ := ToBuffer(content)
return c.Post(path.String(), jsonContentType, buffer) requ, _ := http.NewRequest(http.MethodPost, path.String(), buffer)
requ.Header.Set("Content-Type", jsonContentType)
return c.Do(requ)
} }

View File

@ -7,31 +7,33 @@ import (
"os" "os"
) )
// 身份信息, 维持整个登陆的Session会话
type Storage struct { type Storage struct {
LoginInfo *LoginInfo LoginInfo *LoginInfo
Request *BaseRequest Request *BaseRequest
Response *WebInitResponse Response *WebInitResponse
} }
// 热登陆存储接口
type HotReloadStorage interface { type HotReloadStorage interface {
GetCookie() []*http.Cookie GetCookie() map[string][]*http.Cookie // 获取client.cookie
GetBaseRequest() *BaseRequest GetBaseRequest() *BaseRequest // 获取BaseRequest
GetLoginInfo() *LoginInfo GetLoginInfo() *LoginInfo // 获取LoginInfo
Dump(cookies []*http.Cookie, req *BaseRequest, info *LoginInfo) error Dump(cookies map[string][]*http.Cookie, req *BaseRequest, info *LoginInfo) error // 实现该方法, 将必要信息进行序列化
Load() error Load() error // 实现该方法, 将存储媒介的内容反序列化
} }
type FileHotReloadStorage struct { // 实现HotReloadStorage接口
Cookie []*http.Cookie // 默认以json文件的形式存储
type JsonFileHotReloadStorage struct {
Cookie map[string][]*http.Cookie
Req *BaseRequest Req *BaseRequest
Info *LoginInfo Info *LoginInfo
filename string filename string
} }
func (f *FileHotReloadStorage) Dump(cookies []*http.Cookie, req *BaseRequest, info *LoginInfo) error { // 将信息写入json文件
f.Cookie = cookies func (f *JsonFileHotReloadStorage) Dump(cookies map[string][]*http.Cookie, req *BaseRequest, info *LoginInfo) error {
f.Req = req
f.Info = info
var ( var (
file *os.File file *os.File
err error err error
@ -55,6 +57,10 @@ func (f *FileHotReloadStorage) Dump(cookies []*http.Cookie, req *BaseRequest, in
} }
defer file.Close() defer file.Close()
f.Cookie = cookies
f.Req = req
f.Info = info
data, err := json.Marshal(f) data, err := json.Marshal(f)
if err != nil { if err != nil {
return err return err
@ -63,7 +69,8 @@ func (f *FileHotReloadStorage) Dump(cookies []*http.Cookie, req *BaseRequest, in
return err return err
} }
func (f *FileHotReloadStorage) Load() error { // 从文件中读取信息
func (f *JsonFileHotReloadStorage) Load() error {
file, err := os.Open(f.filename) file, err := os.Open(f.filename)
if err != nil { if err != nil {
return err return err
@ -76,18 +83,18 @@ func (f *FileHotReloadStorage) Load() error {
return json.Unmarshal(buffer.Bytes(), f) return json.Unmarshal(buffer.Bytes(), f)
} }
func (f FileHotReloadStorage) GetCookie() []*http.Cookie { func (f *JsonFileHotReloadStorage) GetCookie() map[string][]*http.Cookie {
return f.Cookie return f.Cookie
} }
func (f FileHotReloadStorage) GetBaseRequest() *BaseRequest { func (f *JsonFileHotReloadStorage) GetBaseRequest() *BaseRequest {
return f.Req return f.Req
} }
func (f FileHotReloadStorage) GetLoginInfo() *LoginInfo { func (f *JsonFileHotReloadStorage) GetLoginInfo() *LoginInfo {
return f.Info return f.Info
} }
func NewFileHotReloadStorage(filename string) *FileHotReloadStorage { func NewJsonFileHotReloadStorage(filename string) *JsonFileHotReloadStorage {
return &FileHotReloadStorage{filename: filename} return &JsonFileHotReloadStorage{filename: filename}
} }