diff --git a/bot.go b/bot.go index bd55484..18877a5 100644 --- a/bot.go +++ b/bot.go @@ -71,23 +71,17 @@ func (b *Bot) HotLogin(storage HotReloadStorage, retry ...bool) error { if _, err := buffer.ReadFrom(storage); err != nil { return b.Login() } + defer b.HotReloadStorage.Close() var item HotReloadStorageItem + if err = json.NewDecoder(&buffer).Decode(&item); err != nil { return err } - cookies := item.Cookies - for u, ck := range cookies { - path, err := url.Parse(u) - if err != nil { - return err - } - b.Caller.Client.Jar.SetCookies(path, ck) + if err = b.hotLoginInit(item); err != nil { + return err } - b.storage.LoginInfo = item.LoginInfo - b.storage.Request = item.BaseRequest - b.Caller.Client.domain = item.WechatDomain // 如果webInit出错,则说明可能身份信息已经失效 // 如果retry为True的话,则进行正常登陆 @@ -100,21 +94,20 @@ func (b *Bot) HotLogin(storage HotReloadStorage, retry ...bool) error { } // 热登陆初始化 -//func (b *Bot) hotLoginInit() error { -// item := b.hotReloadStorage.GetHotReloadStorageItem() -// cookies := item.Cookies -// for u, ck := range cookies { -// path, err := url.Parse(u) -// if err != nil { -// return err -// } -// b.Caller.Client.Jar.SetCookies(path, ck) -// } -// b.storage.LoginInfo = item.LoginInfo -// b.storage.Request = item.BaseRequest -// b.Caller.Client.domain = item.WechatDomain -// return nil -//} +func (b *Bot) hotLoginInit(item HotReloadStorageItem) error { + cookies := item.Cookies + for u, ck := range cookies { + path, err := url.Parse(u) + if err != nil { + return err + } + b.Caller.Client.Jar.SetCookies(path, ck) + } + b.storage.LoginInfo = item.LoginInfo + b.storage.Request = item.BaseRequest + b.Caller.Client.domain = item.WechatDomain + return nil +} // Login 用户登录 // 该方法会一直阻塞,直到用户扫码登录,或者二维码过期 @@ -328,9 +321,10 @@ func (b *Bot) DumpHotReloadStorage() error { if err != nil { return err } - _, err = b.HotReloadStorage.Write(data) - return err - + if _, err = b.HotReloadStorage.Write(data); err != nil { + return err + } + return b.HotReloadStorage.Close() } // OnLogin is a setter for LoginCallBack diff --git a/stroage.go b/stroage.go index a801690..e59e534 100644 --- a/stroage.go +++ b/stroage.go @@ -21,7 +21,7 @@ type HotReloadStorageItem struct { } // HotReloadStorage 热登陆存储接口 -type HotReloadStorage io.ReadWriter +type HotReloadStorage io.ReadWriteCloser // JsonFileHotReloadStorage 实现HotReloadStorage接口 // 默认以json文件的形式存储 @@ -37,22 +37,26 @@ func (j *JsonFileHotReloadStorage) Read(p []byte) (n int, err error) { return 0, err } } - n, err = j.file.Read(p) - if err == io.EOF { - j.file.Close() - } - return n, err + return j.file.Read(p) } func (j *JsonFileHotReloadStorage) Write(p []byte) (n int, err error) { - file, err := os.Create(j.FileName) + j.file, err = os.Create(j.FileName) if err != nil { return 0, err } - defer file.Close() - return file.Write(p) + return j.file.Write(p) } -func NewJsonFileHotReloadStorage(filename string) *JsonFileHotReloadStorage { +func (j *JsonFileHotReloadStorage) Close() error { + if j.file != nil { + return j.file.Close() + } + return nil +} + +func NewJsonFileHotReloadStorage(filename string) HotReloadStorage { return &JsonFileHotReloadStorage{FileName: filename} } + +var _ HotReloadStorage = &JsonFileHotReloadStorage{}