Compare commits

...

14 Commits

Author SHA1 Message Date
多吃点苹果
3e6fc41298
[fix]: 修改同步消息逻辑 (#236) 2023-02-05 11:23:05 +08:00
ford
3c7ac0cc75
【opt】优化好友公众号群组获取接口防止频繁发送网络请求 (#234)
Co-authored-by: wenyoufu <wenyoufu@jd.com>
2023-02-05 11:20:40 +08:00
多吃点苹果
99af4a2685
[refactor]: 添加 CookieGroup (#233) 2023-02-05 00:20:37 +08:00
多吃点苹果
0c57ab1ed5
更新 Group Display (#232) 2023-02-04 23:55:55 +08:00
多吃点苹果
a72c165c59
删除根据备注查找群组功能 (#231) 2023-02-04 12:04:36 +08:00
多吃点苹果
35a348f0af
[feat]: 添加最近联系人和公众号文章列表 (#230) 2023-02-04 11:59:17 +08:00
多吃点苹果
66c4bebd1f
提升上传文件性能 (#228) 2023-02-03 22:17:40 +08:00
多吃点苹果
fbfd691cb4
[style]: update User display (#227) 2023-02-03 17:57:58 +08:00
多吃点苹果
5194ad4965
[style]: 移除 DispatchMessage (#224) 2023-02-02 10:24:18 +08:00
多吃点苹果
eccc25e66e
[style]: Deprecated NewJsonFileHotReloadStorage (#223) 2023-02-02 10:06:17 +08:00
多吃点苹果
d77bb0a4cb
[feat]: 支持用户自定义热存储数据的序列化和反序列化 (#222) 2023-02-02 00:15:46 +08:00
多吃点苹果
e9c89f9ac8
[style]: 支持扫码登录自定义uuid (#221) 2023-02-02 00:05:26 +08:00
多吃点苹果
6629e77fd5
[feat]: 支持自定义添加 context 用于控制 bot 存活 (#220) 2023-02-01 23:54:11 +08:00
多吃点苹果
76bd0a5648
[fix]: 解决定时器同步数据到热存储中的数据竞争问题 https://github.com/eatmoreapple/openwech… (#219) 2023-02-01 23:43:10 +08:00
14 changed files with 196 additions and 90 deletions

38
bot.go
View File

@ -2,7 +2,6 @@ package openwechat
import ( import (
"context" "context"
"encoding/json"
"errors" "errors"
"io" "io"
"log" "log"
@ -20,13 +19,14 @@ type Bot struct {
SyncCheckCallback func(resp SyncCheckResponse) // 心跳回调 SyncCheckCallback func(resp SyncCheckResponse) // 心跳回调
MessageHandler MessageHandler // 获取消息成功的handle MessageHandler MessageHandler // 获取消息成功的handle
MessageErrorHandler func(err error) bool // 获取消息发生错误的handle, 返回true则尝试继续监听 MessageErrorHandler func(err error) bool // 获取消息发生错误的handle, 返回true则尝试继续监听
Serializer Serializer // 序列化器, 默认为json
Storage *Storage
Caller *Caller
once sync.Once once sync.Once
err error err error
context context.Context context context.Context
cancel context.CancelFunc cancel context.CancelFunc
Caller *Caller
self *Self self *Self
Storage *Storage
hotReloadStorage HotReloadStorage hotReloadStorage HotReloadStorage
uuid string uuid string
loginUUID *string loginUUID *string
@ -51,6 +51,7 @@ func (b *Bot) Alive() bool {
// @description: 设置设备Id // @description: 设置设备Id
// @receiver b // @receiver b
// @param deviceId // @param deviceId
// TODO ADD INTO LOGIN OPTION
func (b *Bot) SetDeviceId(deviceId string) { func (b *Bot) SetDeviceId(deviceId string) {
b.deviceId = deviceId b.deviceId = deviceId
} }
@ -84,7 +85,7 @@ func (b *Bot) login(login BotLogin) (err error) {
// Login 用户登录 // Login 用户登录
func (b *Bot) Login() error { func (b *Bot) Login() error {
scanLogin := &SacnLogin{} scanLogin := &SacnLogin{UUID: b.loginUUID}
return b.login(scanLogin) return b.login(scanLogin)
} }
@ -168,9 +169,10 @@ func (b *Bot) WebInit() error {
return err return err
} }
// 设置当前的用户 // 设置当前的用户
b.self = &Self{bot: b, User: &resp.User} b.self = &Self{bot: b, User: resp.User}
b.self.formatEmoji() b.self.formatEmoji()
b.self.self = b.self b.self.self = b.self
resp.ContactList.init(b.self)
b.Storage.Response = resp b.Storage.Response = resp
// 通知手机客户端已经登录 // 通知手机客户端已经登录
@ -221,8 +223,8 @@ func (b *Bot) syncCheck() error {
if !resp.Success() { if !resp.Success() {
return resp.Err() return resp.Err()
} }
// 如果Selector不为0则获取消息 switch resp.Selector {
if !resp.NorMal() { case SelectorNewMsg:
messages, err := b.syncMessage() messages, err := b.syncMessage()
if err != nil { if err != nil {
return err return err
@ -235,8 +237,12 @@ func (b *Bot) syncCheck() error {
// 默认同步调用 // 默认同步调用
// 如果异步调用则需自行处理 // 如果异步调用则需自行处理
// 如配合 openwechat.MessageMatchDispatcher 使用 // 如配合 openwechat.MessageMatchDispatcher 使用
// NOTE: 请确保 MessageHandler 不会阻塞,否则会导致收不到后续的消息
b.MessageHandler(message) b.MessageHandler(message)
} }
case SelectorModContact:
case SelectorAddOrDelContact:
case SelectorModChatRoom:
} }
} }
return err return err
@ -295,7 +301,7 @@ func (b *Bot) DumpTo(writer io.Writer) error {
WechatDomain: b.Caller.Client.Domain, WechatDomain: b.Caller.Client.Domain,
UUID: b.uuid, UUID: b.uuid,
} }
return json.NewEncoder(writer).Encode(item) return b.Serializer.Encode(writer, item)
} }
// IsHot returns true if is hot login otherwise false // IsHot returns true if is hot login otherwise false
@ -303,7 +309,7 @@ func (b *Bot) IsHot() bool {
return b.hotReloadStorage != nil return b.hotReloadStorage != nil
} }
// UUID returns current uuid of bot // UUID returns current UUID of bot
func (b *Bot) UUID() string { func (b *Bot) UUID() string {
return b.uuid return b.uuid
} }
@ -311,7 +317,8 @@ func (b *Bot) UUID() string {
// SetUUID // SetUUID
// @description: 设置UUID可以用来手动登录用 // @description: 设置UUID可以用来手动登录用
// @receiver b // @receiver b
// @param uuid // @param UUID
// TODO ADD INTO LOGIN OPTION
func (b *Bot) SetUUID(uuid string) { func (b *Bot) SetUUID(uuid string) {
b.loginUUID = &uuid b.loginUUID = &uuid
} }
@ -326,8 +333,7 @@ func (b *Bot) reload() error {
return errors.New("hotReloadStorage is nil") return errors.New("hotReloadStorage is nil")
} }
var item HotReloadStorageItem var item HotReloadStorageItem
err := json.NewDecoder(b.hotReloadStorage).Decode(&item) if err := b.Serializer.Decode(b.hotReloadStorage, &item); err != nil {
if err != nil {
return err return err
} }
b.Caller.Client.SetCookieJar(item.Jar) b.Caller.Client.SetCookieJar(item.Jar)
@ -345,7 +351,13 @@ func NewBot(c context.Context) *Bot {
// 默认行为为网页版微信模式 // 默认行为为网页版微信模式
caller.Client.SetMode(normal) caller.Client.SetMode(normal)
ctx, cancel := context.WithCancel(c) ctx, cancel := context.WithCancel(c)
return &Bot{Caller: caller, Storage: &Storage{}, context: ctx, cancel: cancel} return &Bot{
Caller: caller,
Storage: &Storage{},
Serializer: &JsonSerializer{},
context: ctx,
cancel: cancel,
}
} }
// DefaultBot 默认的Bot的构造方法, // DefaultBot 默认的Bot的构造方法,

View File

@ -1,6 +1,7 @@
package openwechat package openwechat
import ( import (
"context"
"time" "time"
) )
@ -132,16 +133,16 @@ func NewSyncReloadDataLoginOption(duration time.Duration) BotLoginOption {
return &SyncReloadDataLoginOption{SyncLoopDuration: duration} return &SyncReloadDataLoginOption{SyncLoopDuration: duration}
} }
// WithModeOption 指定使用哪种客户端模式 // withModeOption 指定使用哪种客户端模式
type WithModeOption struct { type withModeOption struct {
mode Mode mode Mode
} }
// Prepare 实现了 BotLoginOption 接口 // Prepare 实现了 BotLoginOption 接口
func (w WithModeOption) Prepare(b *Bot) { b.Caller.Client.SetMode(w.mode) } func (w withModeOption) Prepare(b *Bot) { b.Caller.Client.SetMode(w.mode) }
func withMode(mode Mode) BotPreparer { func withMode(mode Mode) BotPreparer {
return WithModeOption{mode: mode} return withModeOption{mode: mode}
} }
// btw, 这两个变量已经变了4回了, 但是为了兼容以前的代码, 还是得想着法儿让用户无感知的更新 // btw, 这两个变量已经变了4回了, 但是为了兼容以前的代码, 还是得想着法儿让用户无感知的更新
@ -153,6 +154,19 @@ var (
Desktop = withMode(desktop) Desktop = withMode(desktop)
) )
// WithContextOption 指定一个 context.Context 用于Bot的生命周期
type WithContextOption struct {
Ctx context.Context
}
// Prepare 实现了 BotLoginOption 接口
func (w WithContextOption) Prepare(b *Bot) {
if w.Ctx == nil {
panic("context is nil")
}
b.context, b.cancel = context.WithCancel(w.Ctx)
}
const ( const (
defaultHotStorageSyncDuration = time.Minute * 5 defaultHotStorageSyncDuration = time.Minute * 5
) )
@ -163,19 +177,21 @@ type BotLogin interface {
} }
// SacnLogin 扫码登录 // SacnLogin 扫码登录
type SacnLogin struct{} type SacnLogin struct {
UUID *string
}
// Login 实现了 BotLogin 接口 // Login 实现了 BotLogin 接口
func (s *SacnLogin) Login(bot *Bot) error { func (s *SacnLogin) Login(bot *Bot) error {
var uuid string var uuid string
if bot.loginUUID == nil { if s.UUID == nil {
var err error var err error
uuid, err = bot.Caller.GetLoginUUID() uuid, err = bot.Caller.GetLoginUUID()
if err != nil { if err != nil {
return err return err
} }
} else { } else {
uuid = *bot.loginUUID uuid = *s.UUID
} }
return s.checkLogin(bot, uuid) return s.checkLogin(bot, uuid)
} }

View File

@ -314,6 +314,8 @@ func (c *Client) WebWxGetHeadImg(user *User) (*http.Response, error) {
return c.Do(req) return c.Do(req)
} }
// WebWxUploadMediaByChunk 分块上传文件
// TODO 优化掉这个函数
func (c *Client) WebWxUploadMediaByChunk(file *os.File, request *BaseRequest, info *LoginInfo, forUserName, toUserName string) (*http.Response, error) { func (c *Client) WebWxUploadMediaByChunk(file *os.File, request *BaseRequest, info *LoginInfo, forUserName, toUserName string) (*http.Response, error) {
// 获取文件上传的类型 // 获取文件上传的类型
contentType, err := GetFileContentType(file) contentType, err := GetFileContentType(file)
@ -358,7 +360,11 @@ func (c *Client) WebWxUploadMediaByChunk(file *os.File, request *BaseRequest, in
path.RawQuery = params.Encode() path.RawQuery = params.Encode()
cookies := c.Jar().Cookies(path) cookies := c.Jar().Cookies(path)
webWxDataTicket := getWebWxDataTicket(cookies)
webWxDataTicket, err := getWebWxDataTicket(cookies)
if err != nil {
return nil, err
}
uploadMediaRequest := map[string]interface{}{ uploadMediaRequest := map[string]interface{}{
"UploadType": 2, "UploadType": 2,
@ -410,16 +416,17 @@ func (c *Client) WebWxUploadMediaByChunk(file *os.File, request *BaseRequest, in
return nil, err return nil, err
} }
var chunkBuff = make([]byte, chunkSize)
var formBuffer = bytes.NewBuffer(nil)
// 分块上传 // 分块上传
for chunk := 0; int64(chunk) < chunks; chunk++ { for chunk := 0; int64(chunk) < chunks; chunk++ {
isLastTime := int64(chunk)+1 == chunks
if chunks > 1 { if chunks > 1 {
content["chunk"] = strconv.Itoa(chunk) content["chunk"] = strconv.Itoa(chunk)
} }
var formBuffer = bytes.NewBuffer(nil) formBuffer.Reset()
writer := multipart.NewWriter(formBuffer) writer := multipart.NewWriter(formBuffer)
@ -434,34 +441,33 @@ func (c *Client) WebWxUploadMediaByChunk(file *os.File, request *BaseRequest, in
} }
w, err := writer.CreateFormFile("filename", file.Name()) w, err := writer.CreateFormFile("filename", file.Name())
if err != nil { if err != nil {
return nil, err return nil, err
} }
chunkData := make([]byte, chunkSize) n, err := file.Read(chunkBuff)
n, err := file.Read(chunkData)
if err != nil && err != io.EOF { if err != nil && err != io.EOF {
return nil, err return nil, err
} }
if _, err = w.Write(chunkBuff[:n]); err != nil {
if _, err = w.Write(chunkData[:n]); err != nil {
return nil, err return nil, err
} }
ct := writer.FormDataContentType() ct := writer.FormDataContentType()
if err = writer.Close(); err != nil { if err = writer.Close(); err != nil {
return nil, err return nil, err
} }
req, _ := http.NewRequest(http.MethodPost, path.String(), formBuffer) req, _ := http.NewRequest(http.MethodPost, path.String(), formBuffer)
req.Header.Set("Content-Type", ct) req.Header.Set("Content-Type", ct)
// 发送数据 // 发送数据
resp, err = c.Do(req) resp, err = c.Do(req)
if err != nil { if err != nil {
return nil, err return nil, err
} }
isLastTime := int64(chunk)+1 == chunks
// 如果不是最后一次, 解析有没有错误 // 如果不是最后一次, 解析有没有错误
if !isLastTime { if !isLastTime {
parser := MessageResponseParser{Reader: resp.Body} parser := MessageResponseParser{Reader: resp.Body}
@ -591,13 +597,18 @@ func (c *Client) WebWxGetVideo(msg *Message, info *LoginInfo) (*http.Response, e
// WebWxGetMedia 获取文件消息的文件响应 // WebWxGetMedia 获取文件消息的文件响应
func (c *Client) WebWxGetMedia(msg *Message, info *LoginInfo) (*http.Response, error) { func (c *Client) WebWxGetMedia(msg *Message, info *LoginInfo) (*http.Response, error) {
path, _ := url.Parse(c.Domain.FileHost() + webwxgetmedia) path, _ := url.Parse(c.Domain.FileHost() + webwxgetmedia)
cookies := c.Jar().Cookies(path)
webWxDataTicket, err := getWebWxDataTicket(cookies)
if err != nil {
return nil, err
}
params := url.Values{} params := url.Values{}
params.Add("sender", msg.FromUserName) params.Add("sender", msg.FromUserName)
params.Add("mediaid", msg.MediaId) params.Add("mediaid", msg.MediaId)
params.Add("encryfilename", msg.EncryFileName) params.Add("encryfilename", msg.EncryFileName)
params.Add("fromuser", strconv.FormatInt(info.WxUin, 10)) params.Add("fromuser", strconv.FormatInt(info.WxUin, 10))
params.Add("pass_ticket", info.PassTicket) params.Add("pass_ticket", info.PassTicket)
params.Add("webwx_data_ticket", getWebWxDataTicket(c.Jar().Cookies(path))) params.Add("webwx_data_ticket", webWxDataTicket)
path.RawQuery = params.Encode() path.RawQuery = params.Encode()
req, _ := http.NewRequest(http.MethodGet, path.String(), nil) req, _ := http.NewRequest(http.MethodGet, path.String(), nil)
req.Header.Add("Referer", c.Domain.BaseHost()+"/") req.Header.Add("Referer", c.Domain.BaseHost()+"/")

View File

@ -9,6 +9,7 @@ import (
) )
// Jar is a struct which as same as cookiejar.Jar // Jar is a struct which as same as cookiejar.Jar
// cookiejar.Jar's fields are private, so we can't use it directly
type Jar struct { type Jar struct {
PsList cookiejar.PublicSuffixList PsList cookiejar.PublicSuffixList
@ -57,3 +58,24 @@ type entry struct {
// equal Creation time. This simplifies testing. // equal Creation time. This simplifies testing.
seqNum uint64 seqNum uint64
} }
// CookieGroup is a group of cookies
type CookieGroup []*http.Cookie
func (c CookieGroup) GetByName(cookieName string) (cookie *http.Cookie, exist bool) {
for _, cookie := range c {
if cookie.Name == cookieName {
return cookie, true
}
}
return nil, false
}
func getWebWxDataTicket(cookies []*http.Cookie) (string, error) {
cookieGroup := CookieGroup(cookies)
cookie, exist := cookieGroup.GetByName("webwx_data_ticket")
if !exist {
return "", ErrWebWxDataTicketNotFound
}
return cookie.Value, nil
}

View File

@ -57,9 +57,9 @@ type WebInitResponse struct {
SKey string SKey string
BaseResponse BaseResponse BaseResponse BaseResponse
SyncKey SyncKey SyncKey SyncKey
User User User *User
MPSubscribeMsgList []MPSubscribeMsg MPSubscribeMsgList []*MPSubscribeMsg
ContactList []User ContactList Members
} }
// MPSubscribeMsg 公众号的订阅信息 // MPSubscribeMsg 公众号的订阅信息
@ -68,12 +68,14 @@ type MPSubscribeMsg struct {
Time int64 Time int64
UserName string UserName string
NickName string NickName string
MPArticleList []struct { MPArticleList []*MPArticle
Title string }
Cover string
Digest string type MPArticle struct {
Url string Title string
} Cover string
Digest string
Url string
} }
type UserDetailItem struct { type UserDetailItem struct {

View File

@ -32,6 +32,9 @@ var (
// ErrLoginTimeout define login timeout error // ErrLoginTimeout define login timeout error
ErrLoginTimeout = errors.New("login timeout") ErrLoginTimeout = errors.New("login timeout")
// ErrWebWxDataTicketNotFound define webwx_data_ticket not found error
ErrWebWxDataTicketNotFound = errors.New("webwx_data_ticket not found")
) )
// Error impl error interface // Error impl error interface

View File

@ -11,14 +11,6 @@ type MessageDispatcher interface {
Dispatch(msg *Message) Dispatch(msg *Message)
} }
// DispatchMessage 跟 MessageDispatcher 结合封装成 MessageHandler
// Deprecated: use MessageMatchDispatcher.AsMessageHandler instead
func DispatchMessage(dispatcher MessageDispatcher) func(msg *Message) {
return func(msg *Message) { dispatcher.Dispatch(msg) }
}
// MessageDispatcher impl
// MessageContextHandler 消息处理函数 // MessageContextHandler 消息处理函数
type MessageContextHandler func(ctx *MessageContext) type MessageContextHandler func(ctx *MessageContext)

View File

@ -38,15 +38,6 @@ func GetRandomDeviceId() string {
return builder.String() return builder.String()
} }
func getWebWxDataTicket(cookies []*http.Cookie) string {
for _, cookie := range cookies {
if cookie.Name == "webwx_data_ticket" {
return cookie.Value
}
}
return ""
}
// GetFileContentType 获取文件上传的类型 // GetFileContentType 获取文件上传的类型
func GetFileContentType(file multipart.File) (string, error) { func GetFileContentType(file multipart.File) (string, error) {
data := make([]byte, 512) data := make([]byte, 512)

View File

@ -10,7 +10,11 @@ type Friend struct{ *User }
// implement fmt.Stringer // implement fmt.Stringer
func (f *Friend) String() string { func (f *Friend) String() string {
return fmt.Sprintf("<Friend:%s>", f.NickName) display := f.NickName
if f.RemarkName != "" {
display = f.RemarkName
}
return fmt.Sprintf("<Friend:%s>", display)
} }
// SetRemarkName 重命名当前好友 // SetRemarkName 重命名当前好友
@ -300,11 +304,6 @@ func (g Groups) SearchByNickName(limit int, nickName string) (results Groups) {
return g.Search(limit, func(group *Group) bool { return group.NickName == nickName }) return g.Search(limit, func(group *Group) bool { return group.NickName == nickName })
} }
// SearchByRemarkName 根据备注查找群组
func (g Groups) SearchByRemarkName(limit int, remarkName string) (results Groups) {
return g.Search(limit, func(group *Group) bool { return group.RemarkName == remarkName })
}
// Search 根据自定义条件查找群组 // Search 根据自定义条件查找群组
func (g Groups) Search(limit int, searchFuncList ...func(group *Group) bool) (results Groups) { func (g Groups) Search(limit int, searchFuncList ...func(group *Group) bool) (results Groups) {
return g.AsMembers().Search(limit, func(user *User) bool { return g.AsMembers().Search(limit, func(user *User) bool {
@ -445,11 +444,6 @@ func (g Groups) GetByUsername(username string) *Group {
return g.SearchByUserName(1, username).First() return g.SearchByUserName(1, username).First()
} }
// GetByRemarkName 根据remarkName查询一个Group
func (g Groups) GetByRemarkName(remarkName string) *Group {
return g.SearchByRemarkName(1, remarkName).First()
}
// GetByNickName 根据nickname查询一个Group // GetByNickName 根据nickname查询一个Group
func (g Groups) GetByNickName(nickname string) *Group { func (g Groups) GetByNickName(nickname string) *Group {
return g.SearchByNickName(1, nickname).First() return g.SearchByNickName(1, nickname).First()

25
serializer.go Normal file
View File

@ -0,0 +1,25 @@
package openwechat
import (
"encoding/json"
"io"
)
// Serializer is an interface for encoding and decoding data.
type Serializer interface {
Encode(writer io.Writer, v interface{}) error
Decode(reader io.Reader, v interface{}) error
}
// JsonSerializer is a serializer for json.
type JsonSerializer struct{}
// Encode encodes v to writer.
func (j JsonSerializer) Encode(writer io.Writer, v interface{}) error {
return json.NewEncoder(writer).Encode(v)
}
// Decode decodes data from reader to v.
func (j JsonSerializer) Decode(reader io.Reader, v interface{}) error {
return json.NewDecoder(reader).Decode(v)
}

View File

@ -241,7 +241,7 @@ dispatcher.OnText(func(ctx *openwechat.MessageContext){
}) })
// 注册消息回调函数 // 注册消息回调函数
bot.MessageHandler = openwechat.DispatchMessage(dispatcher) bot.MessageHandler = dispatcher.AsMessageHandler()
``` ```
`openwechat.DispatchMessage`会将消息转发给`dispatcher`对象处理 `openwechat.DispatchMessage`会将消息转发给`dispatcher`对象处理

View File

@ -3,6 +3,7 @@ package openwechat
import ( import (
"io" "io"
"os" "os"
"sync"
"time" "time"
) )
@ -24,14 +25,17 @@ type HotReloadStorageItem struct {
// HotReloadStorage 热登陆存储接口 // HotReloadStorage 热登陆存储接口
type HotReloadStorage io.ReadWriter type HotReloadStorage io.ReadWriter
// jsonFileHotReloadStorage 实现HotReloadStorage接口 // fileHotReloadStorage 实现HotReloadStorage接口
// 默认json文件的形式存储 // 以文件的形式存储
type jsonFileHotReloadStorage struct { type fileHotReloadStorage struct {
filename string filename string
file *os.File file *os.File
lock sync.Mutex
} }
func (j *jsonFileHotReloadStorage) Read(p []byte) (n int, err error) { func (j *fileHotReloadStorage) Read(p []byte) (n int, err error) {
j.lock.Lock()
defer j.lock.Unlock()
if j.file == nil { if j.file == nil {
j.file, err = os.OpenFile(j.filename, os.O_RDWR, 0600) j.file, err = os.OpenFile(j.filename, os.O_RDWR, 0600)
if os.IsNotExist(err) { if os.IsNotExist(err) {
@ -44,38 +48,47 @@ func (j *jsonFileHotReloadStorage) Read(p []byte) (n int, err error) {
return j.file.Read(p) return j.file.Read(p)
} }
func (j *jsonFileHotReloadStorage) Write(p []byte) (n int, err error) { func (j *fileHotReloadStorage) Write(p []byte) (n int, err error) {
j.lock.Lock()
defer j.lock.Unlock()
if j.file == nil { if j.file == nil {
j.file, err = os.Create(j.filename) j.file, err = os.Create(j.filename)
if err != nil { if err != nil {
return 0, err return 0, err
} }
} }
// 为什么这里要对文件进行Truncate操作呢? // reset offset and truncate file
// 这是为了方便每次Dump的时候对文件进行重新写入, 而不是追加
// json序列化写入只会调用一次Write方法, 所以不要把这个方法当成io.Writer的Write方法
if _, err = j.file.Seek(0, io.SeekStart); err != nil { if _, err = j.file.Seek(0, io.SeekStart); err != nil {
return return
} }
if err = j.file.Truncate(0); err != nil { if err = j.file.Truncate(0); err != nil {
return return
} }
// json decode only write once
return j.file.Write(p) return j.file.Write(p)
} }
func (j *jsonFileHotReloadStorage) Close() error { func (j *fileHotReloadStorage) Close() error {
j.lock.Lock()
defer j.lock.Unlock()
if j.file == nil { if j.file == nil {
return nil return nil
} }
return j.file.Close() return j.file.Close()
} }
// NewJsonFileHotReloadStorage 创建JsonFileHotReloadStorage // Deprecated: use NewFileHotReloadStorage instead
// 不再单纯以json的格式存储支持了用户自定义序列化方式
func NewJsonFileHotReloadStorage(filename string) io.ReadWriteCloser { func NewJsonFileHotReloadStorage(filename string) io.ReadWriteCloser {
return &jsonFileHotReloadStorage{filename: filename} return NewFileHotReloadStorage(filename)
} }
var _ HotReloadStorage = (*jsonFileHotReloadStorage)(nil) // NewFileHotReloadStorage implements HotReloadStorage
func NewFileHotReloadStorage(filename string) io.ReadWriteCloser {
return &fileHotReloadStorage{filename: filename}
}
var _ HotReloadStorage = (*fileHotReloadStorage)(nil)
type HotReloadStorageSyncer struct { type HotReloadStorageSyncer struct {
duration time.Duration duration time.Duration

View File

@ -25,7 +25,11 @@ func (s SyncCheckResponse) Success() bool {
} }
func (s SyncCheckResponse) NorMal() bool { func (s SyncCheckResponse) NorMal() bool {
return s.Success() && s.Selector == "0" return s.Success() && s.Selector == SelectorNormal
}
func (s SyncCheckResponse) HasNewMessage() bool {
return s.Success() && s.Selector == SelectorNewMsg
} }
func (s SyncCheckResponse) Err() error { func (s SyncCheckResponse) Err() error {

33
user.go
View File

@ -58,14 +58,14 @@ type User struct {
// implement fmt.Stringer // implement fmt.Stringer
func (u *User) String() string { func (u *User) String() string {
format := "User" format := "User"
if u.IsFriend() { if u.IsSelf() {
format = "Self"
} else if u.IsFriend() {
format = "Friend" format = "Friend"
} else if u.IsGroup() { } else if u.IsGroup() {
format = "Group" format = "Group"
} else if u.IsMP() { } else if u.IsMP() {
format = "MP" format = "MP"
} else if u.IsSelf() {
format = "Self"
} }
return fmt.Sprintf("<%s:%s>", format, u.NickName) return fmt.Sprintf("<%s:%s>", format, u.NickName)
} }
@ -288,13 +288,18 @@ func (s *Self) FileHelper() *Friend {
} }
return s.fileHelper return s.fileHelper
} }
func (s *Self) ChkFrdGrpMpNil() bool {
return s.friends == nil && s.groups == nil && s.mps == nil
}
// Friends 获取所有的好友 // Friends 获取所有的好友
func (s *Self) Friends(update ...bool) (Friends, error) { func (s *Self) Friends(update ...bool) (Friends, error) {
if s.friends == nil || (len(update) > 0 && update[0]) { if (len(update) > 0 && update[0]) || s.ChkFrdGrpMpNil() {
if _, err := s.Members(true); err != nil { if _, err := s.Members(true); err != nil {
return nil, err return nil, err
} }
}
if s.friends == nil || (len(update) > 0 && update[0]) {
s.friends = s.members.Friends() s.friends = s.members.Friends()
} }
return s.friends, nil return s.friends, nil
@ -302,10 +307,14 @@ func (s *Self) Friends(update ...bool) (Friends, error) {
// Groups 获取所有的群组 // Groups 获取所有的群组
func (s *Self) Groups(update ...bool) (Groups, error) { func (s *Self) Groups(update ...bool) (Groups, error) {
if s.groups == nil || (len(update) > 0 && update[0]) {
if (len(update) > 0 && update[0]) || s.ChkFrdGrpMpNil() {
if _, err := s.Members(true); err != nil { if _, err := s.Members(true); err != nil {
return nil, err return nil, err
} }
}
if s.groups == nil || (len(update) > 0 && update[0]) {
s.groups = s.members.Groups() s.groups = s.members.Groups()
} }
return s.groups, nil return s.groups, nil
@ -313,10 +322,12 @@ func (s *Self) Groups(update ...bool) (Groups, error) {
// Mps 获取所有的公众号 // Mps 获取所有的公众号
func (s *Self) Mps(update ...bool) (Mps, error) { func (s *Self) Mps(update ...bool) (Mps, error) {
if s.mps == nil || (len(update) > 0 && update[0]) { if (len(update) > 0 && update[0]) || s.ChkFrdGrpMpNil() {
if _, err := s.Members(true); err != nil { if _, err := s.Members(true); err != nil {
return nil, err return nil, err
} }
}
if s.mps == nil || (len(update) > 0 && update[0]) {
s.mps = s.members.MPs() s.mps = s.members.MPs()
} }
return s.mps, nil return s.mps, nil
@ -668,6 +679,16 @@ func (s *Self) SendVideoToGroups(video io.Reader, delay time.Duration, groups ..
return s.sendVideoToMembers(video, delay, members...) return s.sendVideoToMembers(video, delay, members...)
} }
// ContactList 获取最近的联系人列表
func (s *Self) ContactList() Members {
return s.Bot().Storage.Response.ContactList
}
// MPSubscribeList 获取部分公众号文章列表
func (s *Self) MPSubscribeList() []*MPSubscribeMsg {
return s.Bot().Storage.Response.MPSubscribeMsgList
}
// Members 抽象的用户组 // Members 抽象的用户组
type Members []*User type Members []*User