From 6b7cc8fcdd4cd107b123f08febe345352a871ea6 Mon Sep 17 00:00:00 2001 From: Veda Maharaj <110698921+SoggySaussages@users.noreply.github.com> Date: Wed, 12 Jun 2024 07:12:24 -0700 Subject: [PATCH] customcommands: thread editing functions (#1665) * cc/threads: add edit thread functions Signed-off-by: SoggySaussages * customcommands: reorder thread functions Signed-off-by: SoggySaussages * cc/threads: fix editThread invitable Signed-off-by: SoggySaussages --------- Signed-off-by: SoggySaussages --- common/templates/context.go | 7 +- common/templates/context_funcs.go | 395 +++++++++++++++++++++++------- common/templates/structs.go | 12 +- lib/discordgo/structs.go | 45 +++- lib/dstate/helpers.go | 5 + lib/dstate/interface.go | 9 +- 6 files changed, 376 insertions(+), 97 deletions(-) diff --git a/common/templates/context.go b/common/templates/context.go index 634a9486f5..fc40e6005c 100644 --- a/common/templates/context.go +++ b/common/templates/context.go @@ -654,14 +654,19 @@ func baseContextFuncs(c *Context) { c.addContextFunc("getThread", c.tmplGetThread) // thread functions + c.addContextFunc("addThreadMember", c.tmplThreadMemberAdd) + c.addContextFunc("closeThread", c.tmplCloseThread) c.addContextFunc("createThread", c.tmplCreateThread) c.addContextFunc("deleteThread", c.tmplDeleteThread) - c.addContextFunc("addThreadMember", c.tmplThreadMemberAdd) + c.addContextFunc("editThread", c.tmplEditThread) + c.addContextFunc("openThread", c.tmplOpenThread) c.addContextFunc("removeThreadMember", c.tmplThreadMemberRemove) // forum functions c.addContextFunc("createForumPost", c.tmplCreateForumPost) c.addContextFunc("deleteForumPost", c.tmplDeleteThread) + c.addContextFunc("pinForumPost", c.tmplPinForumPost(false)) + c.addContextFunc("unpinForumPost", c.tmplPinForumPost(true)) c.addContextFunc("currentUserAgeHuman", c.tmplCurrentUserAgeHuman) c.addContextFunc("currentUserAgeMinutes", c.tmplCurrentUserAgeMinutes) diff --git a/common/templates/context_funcs.go b/common/templates/context_funcs.go index a30395c4c2..f315474e91 100644 --- a/common/templates/context_funcs.go +++ b/common/templates/context_funcs.go @@ -953,15 +953,73 @@ func (c *Context) tmplGetThread(channel interface{}) (*CtxChannel, error) { return CtxChannelFromCS(cstate), nil } -func (c *Context) AddThreadToGuildSet(t *dstate.ChannelState) { - // Perform a copy so we don't mutate global array - gsCopy := *c.GS - gsCopy.Threads = make([]dstate.ChannelState, len(c.GS.Threads), len(c.GS.Threads)+1) - copy(gsCopy.Threads, c.GS.Threads) +func (c *Context) tmplThreadMemberAdd(threadID, memberID interface{}) string { - // Add new thread to copied guild state - gsCopy.Threads = append(gsCopy.Threads, *t) - c.GS = &gsCopy + if c.IncreaseCheckGenericAPICall() { + return "" + } + + tID := c.ChannelArg(threadID) + if tID == 0 { + return "" + } + + cstate := c.GS.GetThread(tID) + if cstate == nil { + return "" + } + + targetID := TargetUserID(memberID) + if targetID == 0 { + return "" + } + + common.BotSession.ThreadMemberAdd(tID, discordgo.StrID(targetID)) + return "" +} + +func (c *Context) tmplCloseThread(channel interface{}, flags ...bool) (string, error) { + + if c.IncreaseCheckCallCounter("edit_channel", 10) { + return "", ErrTooManyCalls + } + + cID := c.ChannelArg(channel) + if cID == 0 { + return "", nil //dont send an error, a nil output would indicate invalid/unknown channel + } + + if c.IncreaseCheckCallCounter("edit_channel_"+strconv.FormatInt(cID, 10), 2) { + return "", ErrTooManyCalls + } + + cstate := c.GS.GetChannelOrThread(cID) + if cstate == nil { + return "", errors.New("thread not in state") + } + + if !cstate.Type.IsThread() { + return "", errors.New("must specify a thread") + } + + edit := &discordgo.ChannelEdit{} + switch len(flags) { + case 0: + archived := true + edit.Archived = &archived + case 1: + locked := true + edit.Locked = &locked + default: + return "", errors.New("too many flags") + } + + _, err := common.BotSession.ChannelEditComplex(cID, edit) + if err != nil { + return "", errors.New("unable to edit thread") + } + + return "", nil } func (c *Context) tmplCreateThread(channel, msgID, name interface{}, optionals ...interface{}) (*CtxChannel, error) { @@ -1029,11 +1087,22 @@ func (c *Context) tmplCreateThread(channel, msgID, name interface{}, optionals . } tstate := dstate.ChannelStateFromDgo(ctxThread) - c.AddThreadToGuildSet(&tstate) + c.addThreadToGuildSet(&tstate) return CtxChannelFromCS(&tstate), nil } +func (c *Context) addThreadToGuildSet(t *dstate.ChannelState) { + // Perform a copy so we don't mutate global array + gsCopy := *c.GS + gsCopy.Threads = make([]dstate.ChannelState, len(c.GS.Threads), len(c.GS.Threads)+1) + copy(gsCopy.Threads, c.GS.Threads) + + // Add new thread to copied guild state + gsCopy.Threads = append(gsCopy.Threads, *t) + c.GS = &gsCopy +} + // This function can delete both basic threads and forum threads func (c *Context) tmplDeleteThread(thread interface{}) (string, error) { if c.IncreaseCheckCallCounterPremium("delete_thread", 1, 1) { @@ -1054,28 +1123,93 @@ func (c *Context) tmplDeleteThread(thread interface{}) (string, error) { return "", nil } -func (c *Context) tmplThreadMemberAdd(threadID, memberID interface{}) string { - if c.IncreaseCheckGenericAPICall() { - return "" +func (c *Context) tmplEditThread(channel interface{}, args ...interface{}) (string, error) { + + if c.IncreaseCheckCallCounter("edit_channel", 10) { + return "", ErrTooManyCalls } - tID := c.ChannelArg(threadID) - if tID == 0 { - return "" + cID := c.ChannelArg(channel) + if cID == 0 { + return "", nil //dont send an error, a nil output would indicate invalid/unknown channel } - cstate := c.GS.GetThread(tID) + if c.IncreaseCheckCallCounter("edit_channel_"+strconv.FormatInt(cID, 10), 2) { + return "", ErrTooManyCalls + } + + cstate := c.GS.GetChannelOrThread(cID) if cstate == nil { - return "" + return "", errors.New("thread not in state") } - targetID := TargetUserID(memberID) - if targetID == 0 { - return "" + if !cstate.Type.IsThread() { + return "", errors.New("must specify a thread") } - common.BotSession.ThreadMemberAdd(tID, discordgo.StrID(targetID)) - return "" + partialThread, err := processThreadArgs(false, cstate, args...) + if err != nil { + return "", err + } + + edit := &discordgo.ChannelEdit{} + if partialThread.RateLimitPerUser != nil { + edit.RateLimitPerUser = partialThread.RateLimitPerUser + } + if partialThread.AppliedTags != nil { + edit.AppliedTags = *partialThread.AppliedTags + } + if partialThread.AutoArchiveDuration != nil { + edit.AutoArchiveDuration = *partialThread.AutoArchiveDuration + } + if partialThread.Invitable != nil { + edit.Invitable = partialThread.Invitable + } + + _, err = common.BotSession.ChannelEditComplex(cID, edit) + if err != nil { + return "", errors.New("unable to edit thread") + } + + return "", nil +} + +func (c *Context) tmplOpenThread(channel interface{}) (string, error) { + + if c.IncreaseCheckCallCounter("edit_channel", 10) { + return "", ErrTooManyCalls + } + + cID := c.ChannelArg(channel) + if cID == 0 { + return "", nil //dont send an error, a nil output would indicate invalid/unknown channel + } + + if c.IncreaseCheckCallCounter("edit_channel_"+strconv.FormatInt(cID, 10), 2) { + return "", ErrTooManyCalls + } + + cstate := c.GS.GetChannelOrThread(cID) + if cstate == nil { + return "", errors.New("thread not in state") + } + + if !cstate.Type.IsThread() { + return "", errors.New("must specify a thread") + } + + falseVar := false + edit := &discordgo.ChannelEdit{ + Archived: &falseVar, + Locked: &falseVar, + } + + _, err := common.BotSession.ChannelEditComplex(cID, edit) + if err != nil { + return "", errors.New("unable to edit thread") + } + + return "", nil } func (c *Context) tmplThreadMemberRemove(threadID, memberID interface{}) string { @@ -1102,7 +1236,73 @@ func (c *Context) tmplThreadMemberRemove(threadID, memberID interface{}) string return "" } -func ConvertTagNameToId(c *dstate.ChannelState, tagName string) int64 { +func (c *Context) tmplCreateForumPost(channel, name, content interface{}, optional ...interface{}) (*CtxChannel, error) { + + // shares same counter as create thread + if c.IncreaseCheckCallCounterPremium("create_thread", 1, 1) { + return nil, ErrTooManyCalls + } + + if content == nil { + return nil, errors.New("post content must not be nil") + } + + cID := c.ChannelArg(channel) + if cID == 0 { + return nil, nil //dont send an error, a nil output would indicate invalid/unknown channel + } + + cstate := c.GS.GetChannel(cID) + if cstate == nil { + return nil, errors.New("channel not in state") + } + + if cstate.Type != discordgo.ChannelTypeGuildForum { + return nil, errors.New("must specify a forum channel") + } + + partialThreaad, err := processThreadArgs(true, cstate, optional...) + if err != nil { + return nil, err + } + + start := &discordgo.ThreadStart{ + Name: ToString(name), + Type: discordgo.ChannelTypeGuildPublicThread, + Invitable: false, + RateLimitPerUser: *partialThreaad.RateLimitPerUser, + AppliedTags: *partialThreaad.AppliedTags, + } + + var msgData *discordgo.MessageSend + switch v := content.(type) { + case string: + if len(v) == 0 { + return nil, errors.New("post content must be non-zero length") + } + msgData, _ = CreateMessageSend("content", v) + case *discordgo.MessageEmbed: + msgData, _ = CreateMessageSend("embed", v) + case *discordgo.MessageSend: + msgData = v + default: + return nil, errors.New("post content must be string, embed, or complex message") + } + + thread, err := common.BotSession.ForumThreadStartComplex(cID, start, msgData) + if err != nil { + return nil, errors.New("unable to create forum post") + } + + tstate := dstate.ChannelStateFromDgo(thread) + tstate.AppliedTags = *partialThreaad.AppliedTags + c.addThreadToGuildSet(&tstate) + + return CtxChannelFromCS(&tstate), nil +} + +func tagIDFromName(c *dstate.ChannelState, tagName string) int64 { + if c.AvailableTags == nil { return 0 } @@ -1117,36 +1317,55 @@ func ConvertTagNameToId(c *dstate.ChannelState, tagName string) int64 { return 0 } -func ProcessOptionalForumPostArgs(c *dstate.ChannelState, values ...interface{}) (int, []int64, error) { +type partialThread struct { + RateLimitPerUser *int + AppliedTags *[]int64 + AutoArchiveDuration *int + Invitable *bool +} + +// Accepts a parent channel and key-value pair arguments. Returns a partial +// channel object with values set according to passed values. +func processThreadArgs(newThread bool, parent *dstate.ChannelState, values ...interface{}) (*partialThread, error) { + + c := &partialThread{} + if newThread { + c = &partialThread{ + RateLimitPerUser: &parent.DefaultThreadRateLimitPerUser, + AppliedTags: &[]int64{}, + } + } + if len(values) == 0 { - return c.DefaultThreadRateLimitPerUser, nil, nil + return c, nil } threadSdict, err := StringKeyDictionary(values...) if err != nil { - return 0, nil, err + return c, err } - rateLimit := c.DefaultThreadRateLimitPerUser - var tags []int64 = nil for key, val := range threadSdict { key = strings.ToLower(key) switch key { case "slowmode": - rateLimit = tmplToInt(val) + ratelimit := tmplToInt(val) + c.RateLimitPerUser = &ratelimit case "tags": - if c.AvailableTags == nil { + if parent.AvailableTags == nil { break } + var tags []int64 v, _ := indirect(reflect.ValueOf(val)) const maxTags = 5 // discord limit if v.Kind() == reflect.String { - tag := ConvertTagNameToId(c, ToString(val)) + tag := tagIDFromName(parent, ToString(val)) // ensure supplied id is valid if tag > 0 { tags = []int64{tag} + c.AppliedTags = &tags } } else if v.Kind() == reflect.Slice { // used to get rid of any duplicate tags the user might have sent @@ -1169,7 +1388,7 @@ func ProcessOptionalForumPostArgs(c *dstate.ChannelState, values ...interface{}) } // try to convert and check if the id is valid - tag := ConvertTagNameToId(c, name) + tag := tagIDFromName(parent, name) if tag == 0 { continue } @@ -1177,80 +1396,84 @@ func ProcessOptionalForumPostArgs(c *dstate.ChannelState, values ...interface{}) seen[name] = struct{}{} tags = append(tags, tag) } + c.AppliedTags = &tags } else { - return 0, nil, errors.New("tags must be of type string or cslice") + return c, errors.New("`tags` must be of type string or cslice") } + case "auto_archive_duration": + duration := tmplToInt(val) + if duration < 60 || duration > 10080 { + return c, errors.New("'auto_archive_duration' must be and integer between 60 and 10080") + } + c.AutoArchiveDuration = &duration + case "invitable": + val, ok := val.(bool) + if ok { + invitable := val + c.Invitable = &invitable + continue + } + return c, errors.New("'invitable' must be a boolean") default: - return 0, nil, errors.New(`invalid key "` + key + `"`) + return c, errors.New(`invalid key "` + key + `"`) } } - return rateLimit, tags, nil + return c, nil } -func (c *Context) tmplCreateForumPost(channel, name, content interface{}, optional ...interface{}) (*CtxChannel, error) { - // shares same counter as create thread - if c.IncreaseCheckCallCounterPremium("create_thread", 1, 1) { - return nil, ErrTooManyCalls - } +func (c *Context) tmplPinForumPost(unpin bool) func(channel interface{}) (string, error) { + return func(channel interface{}) (string, error) { - if content == nil { - return nil, errors.New("post content must not be nil") - } + if c.IncreaseCheckCallCounter("edit_channel", 10) { + return "", ErrTooManyCalls + } - cID := c.ChannelArg(channel) - if cID == 0 { - return nil, nil // dont send an error, a nil output would indicate invalid/unknown channel - } + cID := c.ChannelArg(channel) + if cID == 0 { + return "", nil //dont send an error, a nil output would indicate invalid/unknown channel + } - cstate := c.GS.GetChannel(cID) - if cstate == nil { - return nil, errors.New("channel not in state") - } + if c.IncreaseCheckCallCounter("edit_channel_"+strconv.FormatInt(cID, 10), 2) { + return "", ErrTooManyCalls + } - if !cstate.Type.IsForum() { - return nil, errors.New("must specify a forum channel") - } + cstate := c.GS.GetChannelOrThread(cID) + if cstate == nil { + return "", errors.New("forum post not in state") + } - rateLimit, tags, err := ProcessOptionalForumPostArgs(cstate, optional...) - if err != nil { - return nil, err - } + if !cstate.Type.IsThread() { + return "", errors.New("must specify a forum post") + } - start := &discordgo.ThreadStart{ - Name: ToString(name), - Type: discordgo.ChannelTypeGuildPublicThread, - Invitable: false, - RateLimitPerUser: rateLimit, - AppliedTags: tags, - } + parentCState := c.GS.GetChannel(cstate.ParentID) + if parentCState == nil { + return "", errors.New("parent channel not in state") + } - var msgData *discordgo.MessageSend - switch v := content.(type) { - case string: - if len(v) == 0 { - return nil, errors.New("post content must be non-zero length") + if parentCState.Type != discordgo.ChannelTypeGuildForum { + return "", errors.New("must specify a forum post") } - msgData, _ = CreateMessageSend("content", v) - case *discordgo.MessageEmbed: - msgData, _ = CreateMessageSend("embed", v) - case *discordgo.MessageSend: - msgData = v - default: - return nil, errors.New("post content must be string, embed, or complex message") - } - thread, err := common.BotSession.ForumThreadStartComplex(cID, start, msgData) - if err != nil { - return nil, errors.New("unable to create forum post") - } + edit := &discordgo.ChannelEdit{} - tstate := dstate.ChannelStateFromDgo(thread) - tstate.AppliedTags = tags - c.AddThreadToGuildSet(&tstate) + flags := cstate.Flags + if unpin { + flags = flags &^ discordgo.ChannelFlagsPinned + } else { + flags |= discordgo.ChannelFlagsPinned + } + edit.Flags = &flags - return CtxChannelFromCS(&tstate), nil + _, err := common.BotSession.ChannelEditComplex(cID, edit) + if err != nil { + return "", errors.New("unable to edit forum post") + } + + return "", nil + } } func (c *Context) tmplGetChannelOrThread(channel interface{}) (*CtxChannel, error) { diff --git a/common/templates/structs.go b/common/templates/structs.go index 48283e64b0..ed0de62c59 100644 --- a/common/templates/structs.go +++ b/common/templates/structs.go @@ -27,8 +27,11 @@ type CtxChannel struct { ParentID int64 `json:"parent_id"` OwnerID int64 `json:"owner_id"` - AvailableTags []discordgo.ForumTag `json:"available_tags"` - AppliedTags []int64 `json:"applied_tags"` + AvailableTags []discordgo.ForumTag `json:"available_tags"` + AppliedTags []int64 `json:"applied_tags"` + Flags discordgo.ChannelFlags `json:"flags"` + Archived bool `json:"archived"` + Locked bool `json:"locked"` } func (c *CtxChannel) Mention() (string, error) { @@ -49,7 +52,7 @@ func CtxChannelFromCS(cs *dstate.ChannelState) *CtxChannel { ID: cs.ID, IsPrivate: cs.IsPrivate(), IsThread: cs.Type.IsThread(), - IsForum: cs.Type.IsForum(), + IsForum: cs.Type == discordgo.ChannelTypeGuildForum, GuildID: cs.GuildID, Name: cs.Name, Type: cs.Type, @@ -62,6 +65,9 @@ func CtxChannelFromCS(cs *dstate.ChannelState) *CtxChannel { OwnerID: cs.OwnerID, AvailableTags: cs.AvailableTags, AppliedTags: cs.AppliedTags, + Flags: cs.Flags, + Archived: cs.Archived, + Locked: cs.Locked, } return ctxChannel diff --git a/lib/discordgo/structs.go b/lib/discordgo/structs.go index 75f9c31774..9a9e184fea 100644 --- a/lib/discordgo/structs.go +++ b/lib/discordgo/structs.go @@ -200,10 +200,6 @@ func (t ChannelType) IsThread() bool { return t == ChannelTypeGuildPrivateThread || t == ChannelTypeGuildPublicThread } -func (t ChannelType) IsForum() bool { - return t == ChannelTypeGuildForum -} - // A Channel holds all data related to an individual Discord channel. type Channel struct { // The ID of the channel. @@ -264,6 +260,9 @@ type Channel struct { // Thread specific fields ThreadMetadata *ThreadMetadata `json:"thread_metadata"` + // ChannelFlags combined as a bitfield. + Flags ChannelFlags `json:"flags"` + // The set of tags that can be used in a forum channel. AvailableTags []ForumTag `json:"available_tags"` @@ -284,6 +283,18 @@ type Channel struct { // The default forum layout view used to display posts in forum channels. // Defaults to ForumLayoutNotSet, which indicates a layout view has not been set by a channel admin. DefaultForumLayout ForumLayout `json:"default_forum_layout"` + + // whether the thread is archived + Archived bool `json:"archived"` + + // the thread will stop showing in the channel list after auto_archive_duration minutes of inactivity, can be set to: 60, 1440, 4320, 10080 + AutoArchiveDuration int `json:"auto_archive_duration,omitempty"` + + // whether the thread is locked; when a thread is locked, only users with MANAGE_THREADS can unarchive it + Locked bool `json:"locked"` + + // whether non-moderators can add other non-moderators to a thread; only available on private threads + Invitable bool `json:"invitable"` } func (c *Channel) GetChannelID() int64 { @@ -299,7 +310,21 @@ func (c *Channel) Mention() string { return fmt.Sprintf("<#%d>", c.ID) } -// A ChannelEdit holds Channel Feild data for a channel edit. +// ChannelFlags is the flags of "channel" (see ChannelFlags* consts) +// https://discord.com/developers/docs/resources/channel#message-object-message-flags +type ChannelFlags int + +// Valid ChannelFlags values +const ( + // ChannelFlagsPinned this thread is pinned to the top of its parent GUILD_FORUM or GUILD_MEDIA channel. + ChannelFlagsPinned ChannelFlags = 1 << 1 + // ChannelFlagsRequireTag whether a tag is required to be specified when creating a thread in a GUILD_FORUM or a GUILD_MEDIA channel. Tags are specified in the applied_tags field. + ChannelFlagsRequireTag ChannelFlags = 1 << 4 + // ChannelFlagsHideMediaDownloadOptions when set hides the embedded media download options. Available only for media channels. + ChannelFlagsHideMediaDownloadOptions ChannelFlags = 1 << 15 +) + +// A ChannelEdit holds Channel Field data for a channel edit. type ChannelEdit struct { Name string `json:"name,omitempty"` Topic string `json:"topic,omitempty"` @@ -309,7 +334,17 @@ type ChannelEdit struct { UserLimit int `json:"user_limit,omitempty"` PermissionOverwrites []*PermissionOverwrite `json:"permission_overwrites,omitempty"` ParentID *null.String `json:"parent_id,omitempty"` + Flags *ChannelFlags `json:"flags,omitempty"` RateLimitPerUser *int `json:"rate_limit_per_user,omitempty"` + + // Threads only + Archived *bool `json:"archived,omitempty"` + AutoArchiveDuration int `json:"auto_archive_duration,omitempty"` + Locked *bool `json:"locked,omitempty"` + Invitable *bool `json:"invitable,omitempty"` + + // NOTE: forum threads only - these are IDs + AppliedTags IDSlice `json:"applied_tags,string,omitempty"` } type RoleCreate struct { diff --git a/lib/dstate/helpers.go b/lib/dstate/helpers.go index 13a5cedc7b..090529672c 100644 --- a/lib/dstate/helpers.go +++ b/lib/dstate/helpers.go @@ -161,6 +161,7 @@ func ChannelStateFromDgo(c *discordgo.Channel) ChannelState { NSFW: c.NSFW, Position: c.Position, Bitrate: c.Bitrate, + Flags: c.Flags, OwnerID: c.OwnerID, AvailableTags: c.AvailableTags, AppliedTags: c.AppliedTags, @@ -168,6 +169,10 @@ func ChannelStateFromDgo(c *discordgo.Channel) ChannelState { DefaultThreadRateLimitPerUser: c.DefaultThreadRateLimitPerUser, DefaultSortOrder: c.DefaultSortOrder, DefaultForumLayout: c.DefaultForumLayout, + Archived: c.Archived, + AutoArchiveDuration: c.AutoArchiveDuration, + Locked: c.Locked, + Invitable: c.Invitable, } } diff --git a/lib/dstate/interface.go b/lib/dstate/interface.go index d48a458e5f..e6ef05b63e 100644 --- a/lib/dstate/interface.go +++ b/lib/dstate/interface.go @@ -259,13 +259,18 @@ type ChannelState struct { UserLimit int `json:"user_limit"` ParentID int64 `json:"parent_id,string"` RateLimitPerUser int `json:"rate_limit_per_user"` + Flags discordgo.ChannelFlags `json:"flags"` OwnerID int64 `json:"owner_id,string"` ThreadMetadata *discordgo.ThreadMetadata `json:"thread_metadata"` PermissionOverwrites []discordgo.PermissionOverwrite `json:"permission_overwrites"` - AvailableTags []discordgo.ForumTag `json:"available_tags"` - AppliedTags []int64 `json:"applied_tags"` + AvailableTags []discordgo.ForumTag `json:"available_tags"` + AppliedTags []int64 `json:"applied_tags"` + Archived bool `json:"archived"` + AutoArchiveDuration int `json:"auto_archive_duration,omitempty"` + Locked bool `json:"locked"` + Invitable bool `json:"invitable"` DefaultReactionEmoji discordgo.ForumDefaultReaction `json:"default_reaction_emoji"` DefaultThreadRateLimitPerUser int `json:"default_thread_rate_limit_per_user"`