diff --git a/bot.go b/bot.go index e599dcd..3391933 100644 --- a/bot.go +++ b/bot.go @@ -12,7 +12,7 @@ type Bot struct { LoginCallBack func(body []byte) // 登陆回调 LogoutCallBack func(bot *Bot) // 退出回调 UUIDCallback func(uuid string) // 获取UUID的回调函数 - MessageHandler func(msg *Message) // 获取消息成功的handle + MessageHandler MessageHandler // 获取消息成功的handle GetMessageErrorHandler func(err error) // 获取消息发生错误的handle isHot bool err error diff --git a/bot_test.go b/bot_test.go index fd15bc8..a5f0a08 100644 --- a/bot_test.go +++ b/bot_test.go @@ -343,3 +343,34 @@ func TestForwardMessage(t *testing.T) { } self.ForwardMessageToGroups(sentM, groups...) } + +func TestMessageMatchDispatcher(t *testing.T) { + dispatcher := NewMessageMatchDispatcher() + + m1 := func(ctx *MessageContext) { + if ctx.Content == "ping" { + ctx.ReplyText("pong") + } + ctx.Next() + t.Log("处理完毕~") + } + + m2 := func(ctx *MessageContext) { + if ctx.Content == "hello" { + ctx.ReplyText("world") + } + } + + dispatcher.OnText(m1, m2) + + dispatcher.OnFriendByRemarkName("1", func(ctx *MessageContext) { ctx.ReplyText("我收到了你的信息了") }) + + bot := defaultBot(Desktop) + bot.MessageHandler = DispatchMessage(dispatcher) + + if err := bot.Login(); err != nil { + t.Error(err) + return + } + bot.Block() +} diff --git a/message_handle.go b/message_handle.go new file mode 100644 index 0000000..0f245b9 --- /dev/null +++ b/message_handle.go @@ -0,0 +1,175 @@ +package openwechat + +// 消息处理函数 +type MessageHandler func(msg *Message) + +// 消息分发处理接口 +// 跟 DispatchMessage 结合封装成 MessageHandler +type MessageDispatcher interface { + Dispatch(msg *Message) +} + +// 跟 MessageDispatcher 结合封装成 MessageHandler +func DispatchMessage(dispatcher MessageDispatcher) func(msg *Message) { + return func(msg *Message) { dispatcher.Dispatch(msg) } +} + +// MessageDispatcher impl + +// MessageMatchDispatcher 消息处理函数 +type MessageContextHandler func(ctx *MessageContext) + +type MessageContextHandlerGroup []MessageContextHandler + +// MessageContext 消息处理上下文对象 +type MessageContext struct { + index int + messageHandlers MessageContextHandlerGroup + *Message +} + +// 主动调用下一个消息处理函数(或开始调用) +func (c *MessageContext) Next() { + c.index++ + for c.index <= len(c.messageHandlers) { + handle := c.messageHandlers[c.index-1] + handle(c) + c.index++ + } +} + +// 消息匹配函数,返回为true则表示匹配 +type matchFunc func(*Message) bool + +type matchNode struct { + matchFunc matchFunc + group MessageContextHandlerGroup +} + +type matchNodes []*matchNode + +// MessageMatchDispatcher impl MessageDispatcher interface +// dispatcher := NewMessageMatchDispatcher() +// dispatcher.OnText(func(msg *Message){ +// msg.ReplyText("hello") +// }) +// bot := DefaultBot() +// bot.MessageHandler = DispatchMessage(dispatcher) +type MessageMatchDispatcher struct { + async bool + matchNodes matchNodes +} + +// Constructor for MessageMatchDispatcher +func NewMessageMatchDispatcher() *MessageMatchDispatcher { + return &MessageMatchDispatcher{} +} + +// 设置是否异步处理 +func (m *MessageMatchDispatcher) SetAsync(async bool) { + m.async = async +} + +// Dispatch impl MessageDispatcher +// 遍历 MessageMatchDispatcher 所有的消息处理函数 +// 获取所有匹配上的函数 +// 执行处理的消息处理方法 +func (m *MessageMatchDispatcher) Dispatch(msg *Message) { + var group MessageContextHandlerGroup + for _, node := range m.matchNodes { + if node.matchFunc(msg) { + group = append(group, node.group...) + } + } + ctx := &MessageContext{Message: msg, messageHandlers: group} + if m.async { + go m.do(ctx) + } else { + m.do(ctx) + } +} + +func (m *MessageMatchDispatcher) do(ctx *MessageContext) { + ctx.Next() +} + +// 注册消息处理函数, 根据自己的需求自定义 +// matchFunc返回true则表示处理对应的handlers +func (m *MessageMatchDispatcher) RegisterHandler(matchFunc matchFunc, handlers ...MessageContextHandler) { + if matchFunc == nil { + panic("matchFunc can not be nil") + } + node := &matchNode{matchFunc: matchFunc, group: handlers} + m.matchNodes = append(m.matchNodes, node) +} + +// 注册处理消息类型为Text的处理函数 +func (m *MessageMatchDispatcher) OnText(handlers ...MessageContextHandler) { + m.RegisterHandler(func(message *Message) bool { + return message.IsText() + }, handlers...) +} + +// 注册处理消息类型为Image的处理函数 +func (m *MessageMatchDispatcher) OnImage(handlers ...MessageContextHandler) { + m.RegisterHandler(func(message *Message) bool { + return message.IsPicture() + }, handlers...) +} + +// 注册处理消息类型为Voice的处理函数 +func (m *MessageMatchDispatcher) OnVoice(handlers ...MessageContextHandler) { + m.RegisterHandler(func(message *Message) bool { + return message.IsVoice() + }, handlers...) +} + +// 注册处理消息类型为FriendAdd的处理函数 +func (m *MessageMatchDispatcher) OnFriendAdd(handlers ...MessageContextHandler) { + m.RegisterHandler(func(message *Message) bool { + return message.IsFriendAdd() + }, handlers...) +} + +// 注册处理消息类型为Card的处理函数 +func (m *MessageMatchDispatcher) OnCard(handlers ...MessageContextHandler) { + m.RegisterHandler(func(message *Message) bool { + return message.IsCard() + }, handlers...) +} + +// 注册根据好友昵称是否匹配的消息处理函数 +func (m *MessageMatchDispatcher) OnFriendByNickName(nickName string, handlers ...MessageContextHandler) { + matchFunc := func(message *Message) bool { + if message.IsSendByFriend() { + sender, err := message.Sender() + return err == nil && sender.NickName == nickName + } + return false + } + m.RegisterHandler(matchFunc, handlers...) +} + +// 注册根据好友备注是否匹配的消息处理函数 +func (m *MessageMatchDispatcher) OnFriendByRemarkName(remarkName string, handlers ...MessageContextHandler) { + f := func(message *Message) bool { + if message.IsSendByFriend() { + sender, err := message.Sender() + return err == nil && sender.RemarkName == remarkName + } + return false + } + m.RegisterHandler(f, handlers...) +} + +// 注册根据群名是否匹配的消息处理函数 +func (m *MessageMatchDispatcher) OnGroupByGroupName(groupName string, handlers ...MessageContextHandler) { + f := func(message *Message) bool { + if message.IsSendByGroup() { + sender, err := message.Sender() + return err == nil && sender.NickName == groupName + } + return false + } + m.RegisterHandler(f, handlers...) +}