diff --git a/base_response.go b/base_response.go index def2db3..3dfcb45 100644 --- a/base_response.go +++ b/base_response.go @@ -1,7 +1,5 @@ package openwechat -import "errors" - type Ret int const ( @@ -28,5 +26,5 @@ func (b BaseResponse) Err() error { if b.Ok() { return nil } - return errors.New(b.Ret.String()) + return b.Ret } diff --git a/bot.go b/bot.go index 119ae94..f4d95af 100644 --- a/bot.go +++ b/bot.go @@ -73,31 +73,39 @@ func (b *Bot) GetCurrentUser() (*Self, error) { // Storage := NewJsonFileHotReloadStorage("Storage.json") // err := bot.HotLogin(Storage, true) // fmt.Println(err) -func (b *Bot) HotLogin(storage HotReloadStorage, retry ...bool) error { - b.hotReloadStorage = storage - - var err error - - // 如果load出错了,就执行正常登陆逻辑 - // 第一次没有数据load都会出错的 - item, err := NewHotReloadStorageItem(storage) - - if err != nil { +func (b *Bot) HotLogin(storage HotReloadStorage, retries ...bool) error { + err := b.hotLogin(storage) + // 判断是否为需要重新登录 + if errors.Is(err, ErrInvalidStorage) { return b.Login() } - - if err = b.hotLoginInit(item); err != nil { - return err - } - - // 如果webInit出错,则说明可能身份信息已经失效 - // 如果retry为True的话,则进行正常登陆 - if err = b.WebInit(); err != nil && (len(retry) > 0 && retry[0]) { - err = b.Login() + if err != nil { + if len(retries) > 0 && retries[0] { + retErr, ok := err.(Ret) + if !ok { + return err + } + if retErr == cookieInvalid { + return b.Login() + } + } } return err } +func (b *Bot) hotLogin(storage HotReloadStorage) error { + b.hotReloadStorage = storage + var item HotReloadStorageItem + err := json.NewDecoder(storage).Decode(&item) + if err != nil { + return err + } + if err = b.hotLoginInit(&item); err != nil { + return err + } + return b.WebInit() +} + // 热登陆初始化 func (b *Bot) hotLoginInit(item *HotReloadStorageItem) error { cookies := item.Cookies diff --git a/errors.go b/errors.go index 4176404..1fbd54d 100644 --- a/errors.go +++ b/errors.go @@ -4,8 +4,6 @@ import ( "errors" ) -var NetworkErr = errors.New("wechat network error") - func IsNetworkError(err error) bool { return errors.Is(err, NetworkErr) } @@ -19,5 +17,18 @@ func IgnoreNetworkError(errHandler func(err error)) func(error) { } } -// ErrForbidden 禁止当前账号登录 -var ErrForbidden = errors.New("login forbidden") +var ( + // ErrForbidden 禁止当前账号登录 + ErrForbidden = errors.New("login forbidden") + + // ErrInvalidStorage define invalid storage error + ErrInvalidStorage = errors.New("invalid storage") + + // NetworkErr define wechat network error + NetworkErr = errors.New("wechat network error") +) + +// Error impl error interface +func (r Ret) Error() string { + return r.String() +} diff --git a/stroage.go b/stroage.go index e641988..399c45a 100644 --- a/stroage.go +++ b/stroage.go @@ -1,8 +1,6 @@ package openwechat import ( - "encoding/json" - "errors" "io" "net/http" "os" @@ -26,16 +24,19 @@ type HotReloadStorageItem struct { // HotReloadStorage 热登陆存储接口 type HotReloadStorage io.ReadWriter -// JsonFileHotReloadStorage 实现HotReloadStorage接口 +// jsonFileHotReloadStorage 实现HotReloadStorage接口 // 默认以json文件的形式存储 -type JsonFileHotReloadStorage struct { - FileName string +type jsonFileHotReloadStorage struct { + filename string file *os.File } -func (j *JsonFileHotReloadStorage) Read(p []byte) (n int, err error) { +func (j *jsonFileHotReloadStorage) Read(p []byte) (n int, err error) { if j.file == nil { - j.file, err = os.Open(j.FileName) + j.file, err = os.Open(j.filename) + if os.IsNotExist(err) { + return 0, ErrInvalidStorage + } if err != nil { return 0, err } @@ -43,9 +44,9 @@ func (j *JsonFileHotReloadStorage) Read(p []byte) (n int, err error) { return j.file.Read(p) } -func (j *JsonFileHotReloadStorage) Write(p []byte) (n int, err error) { +func (j *jsonFileHotReloadStorage) Write(p []byte) (n int, err error) { if j.file == nil { - j.file, err = os.Create(j.FileName) + j.file, err = os.Create(j.filename) if err != nil { return 0, err } @@ -53,21 +54,16 @@ func (j *JsonFileHotReloadStorage) Write(p []byte) (n int, err error) { return j.file.Write(p) } +func (j *jsonFileHotReloadStorage) Close() error { + if j.file == nil { + return nil + } + return j.file.Close() +} + // NewJsonFileHotReloadStorage 创建JsonFileHotReloadStorage -func NewJsonFileHotReloadStorage(filename string) HotReloadStorage { - return &JsonFileHotReloadStorage{FileName: filename} +func NewJsonFileHotReloadStorage(filename string) io.ReadWriteCloser { + return &jsonFileHotReloadStorage{filename: filename} } -var _ HotReloadStorage = (*JsonFileHotReloadStorage)(nil) - -func NewHotReloadStorageItem(storage HotReloadStorage) (*HotReloadStorageItem, error) { - if storage == nil { - return nil, errors.New("storage can't be nil") - } - var item HotReloadStorageItem - - if err := json.NewDecoder(storage).Decode(&item); err != nil { - return nil, err - } - return &item, nil -} +var _ HotReloadStorage = (*jsonFileHotReloadStorage)(nil)