From a440e80d726d03f9ed3553adc847f6c365779239 Mon Sep 17 00:00:00 2001 From: eatmoreapple <15055461510@163.com> Date: Sun, 1 Aug 2021 13:46:55 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9HotReloadStorage=E6=8E=A5?= =?UTF-8?q?=E5=8F=A3=E5=AE=9A=E4=B9=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- bot.go | 56 ++++++++++++++++++++++++++++++++---------------- bot_test.go | 14 ++++++++++++ stroage.go | 61 +++++++++++++++++------------------------------------ 3 files changed, 71 insertions(+), 60 deletions(-) diff --git a/bot.go b/bot.go index d3eb0bc..911d3db 100644 --- a/bot.go +++ b/bot.go @@ -1,7 +1,9 @@ package openwechat import ( + "bytes" "context" + "encoding/json" "errors" "log" "net/url" @@ -65,13 +67,26 @@ func (b *Bot) HotLogin(storage HotReloadStorage, retry ...bool) error { // 如果load出错了,就执行正常登陆逻辑 // 第一次没有数据load都会出错的 - if err = storage.Load(); err != nil { + var buffer bytes.Buffer + if _, err := buffer.ReadFrom(storage); err != nil { return b.Login() } - if err = b.hotLoginInit(); err != nil { + var item HotReloadStorageItem + if err = json.NewDecoder(&buffer).Decode(&item); err != nil { return err } + cookies := item.Cookies + for u, ck := range cookies { + path, err := url.Parse(u) + if err != nil { + return err + } + b.Caller.Client.Jar.SetCookies(path, ck) + } + b.storage.LoginInfo = item.LoginInfo + b.storage.Request = item.BaseRequest + b.Caller.Client.domain = item.WechatDomain // 如果webInit出错,则说明可能身份信息已经失效 // 如果retry为True的话,则进行正常登陆 @@ -84,21 +99,21 @@ func (b *Bot) HotLogin(storage HotReloadStorage, retry ...bool) error { } // 热登陆初始化 -func (b *Bot) hotLoginInit() error { - item := b.hotReloadStorage.GetHotReloadStorageItem() - cookies := item.Cookies - for u, ck := range cookies { - path, err := url.Parse(u) - if err != nil { - return err - } - b.Caller.Client.Jar.SetCookies(path, ck) - } - b.storage.LoginInfo = item.LoginInfo - b.storage.Request = item.BaseRequest - b.Caller.Client.domain = item.WechatDomain - return nil -} +//func (b *Bot) hotLoginInit() error { +// item := b.hotReloadStorage.GetHotReloadStorageItem() +// cookies := item.Cookies +// for u, ck := range cookies { +// path, err := url.Parse(u) +// if err != nil { +// return err +// } +// b.Caller.Client.Jar.SetCookies(path, ck) +// } +// b.storage.LoginInfo = item.LoginInfo +// b.storage.Request = item.BaseRequest +// b.Caller.Client.domain = item.WechatDomain +// return nil +//} // Login 用户登录 // 该方法会一直阻塞,直到用户扫码登录,或者二维码过期 @@ -307,7 +322,12 @@ func (b *Bot) DumpHotReloadStorage() error { LoginInfo: b.storage.LoginInfo, WechatDomain: b.Caller.Client.domain, } - return b.hotReloadStorage.Dump(item) + data, err := json.Marshal(item) + if err != nil { + return err + } + _, err = b.hotReloadStorage.Write(data) + return err } // OnLogin is a setter for LoginCallBack diff --git a/bot_test.go b/bot_test.go index 461f95c..1e1a72d 100644 --- a/bot_test.go +++ b/bot_test.go @@ -127,3 +127,17 @@ func TestSender(t *testing.T) { } bot.Block() } + +func TestHotReloadStorage(t *testing.T) { + bot := DefaultBot(Desktop) + bot.MessageHandler = func(msg *Message) { + if msg.IsText() && msg.Content == "ping" { + msg.ReplyText("pong") + } + } + if err := bot.HotLogin(NewJsonFileHotReloadStorage("test.json")); err != nil { + t.Error(err) + return + } + bot.Block() +} diff --git a/stroage.go b/stroage.go index 5751c5f..a801690 100644 --- a/stroage.go +++ b/stroage.go @@ -1,8 +1,7 @@ package openwechat import ( - "bytes" - "encoding/json" + "io" "net/http" "os" ) @@ -22,60 +21,38 @@ type HotReloadStorageItem struct { } // HotReloadStorage 热登陆存储接口 -type HotReloadStorage interface { - GetHotReloadStorageItem() HotReloadStorageItem // 获取HotReloadStorageItem - Dump(item HotReloadStorageItem) error // 实现该方法, 将必要信息进行序列化 - Load() error // 实现该方法, 将存储媒介的内容反序列化 -} +type HotReloadStorage io.ReadWriter // JsonFileHotReloadStorage 实现HotReloadStorage接口 // 默认以json文件的形式存储 type JsonFileHotReloadStorage struct { - item HotReloadStorageItem - filename string + FileName string + file *os.File } -// Dump 将信息写入json文件 -func (f *JsonFileHotReloadStorage) Dump(item HotReloadStorageItem) error { - - file, err := os.OpenFile(f.filename, os.O_RDWR|os.O_CREATE|os.O_TRUNC, os.ModePerm) - - if err != nil { - return err +func (j *JsonFileHotReloadStorage) Read(p []byte) (n int, err error) { + if j.file == nil { + j.file, err = os.Open(j.FileName) + if err != nil { + return 0, err + } } - - defer file.Close() - - f.item = item - - data, err := json.Marshal(f.item) - if err != nil { - return err + n, err = j.file.Read(p) + if err == io.EOF { + j.file.Close() } - _, err = file.Write(data) - return err + return n, err } -// Load 从文件中读取信息 -func (f *JsonFileHotReloadStorage) Load() error { - file, err := os.Open(f.filename) - +func (j *JsonFileHotReloadStorage) Write(p []byte) (n int, err error) { + file, err := os.Create(j.FileName) if err != nil { - return err + return 0, err } defer file.Close() - var buffer bytes.Buffer - if _, err := buffer.ReadFrom(file); err != nil { - return err - } - err = json.Unmarshal(buffer.Bytes(), &f.item) - return err -} - -func (f *JsonFileHotReloadStorage) GetHotReloadStorageItem() HotReloadStorageItem { - return f.item + return file.Write(p) } func NewJsonFileHotReloadStorage(filename string) *JsonFileHotReloadStorage { - return &JsonFileHotReloadStorage{filename: filename} + return &JsonFileHotReloadStorage{FileName: filename} }