Skip to content

Commit

Permalink
feat: add 'url list' command and improve command responses
Browse files Browse the repository at this point in the history
  • Loading branch information
cyb3rko committed Dec 26, 2024
1 parent 6da276c commit c631b73
Show file tree
Hide file tree
Showing 4 changed files with 116 additions and 21 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ Guardian supports URL filtering based on a customizable domain list.
**Examples**:
- `!gd url block t.me`
- `!gd url unblock t.me`
- `!gd url list`

### URL Phishing Check 🗡

Expand Down
85 changes: 69 additions & 16 deletions db/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@ type mimetype struct {
count int
}

type domain struct {
name string
count int
}

var tables = []table{
{"domains", "name TEXT PRIMARY KEY, count INT"},
{"mimetypes", "name TEXT PRIMARY KEY, count INT"},
Expand All @@ -34,10 +39,8 @@ func InitDB() *sql.DB {
}

func IsDomainBlocked(db *sql.DB, domain string) bool {
query := db.QueryRow("SELECT count FROM domains WHERE name = ?", domain)
var count int
err := query.Scan(&count)
if err != nil {
found, count := findDomain(db, domain)
if !found {
// not found in database, implicitly allowed
return false
}
Expand All @@ -48,10 +51,8 @@ func IsDomainBlocked(db *sql.DB, domain string) bool {
}

func IsMimeBlocked(db *sql.DB, mime string) bool {
query := db.QueryRow("SELECT count FROM mimetypes WHERE name = ?", mime)
var count int
err := query.Scan(&count)
if err != nil {
found, count := findMime(db, mime)
if !found {
// not found in database, implicitly allowed
return false
}
Expand All @@ -61,24 +62,69 @@ func IsMimeBlocked(db *sql.DB, mime string) bool {
return true
}

func BlockDomain(db *sql.DB, domain string) bool {
func BlockDomain(db *sql.DB, domain string) (bool, string) {
found, _ := findDomain(db, domain)
if found {
return false, "Domain already blocked"
}
_, err := db.Exec("INSERT INTO domains (name, count) values (?, 0)", domain)
return err == nil
return err == nil, ""
}

func UnblockDomain(db *sql.DB, domain string) bool {
func UnblockDomain(db *sql.DB, domain string) (bool, string) {
found, _ := findDomain(db, domain)
if !found {
return false, "Domain not blocked"
}
_, err := db.Exec("DELETE FROM domains WHERE name = ?", domain)
return err == nil
if err == nil {
return true, ""
} else {
return false, fmt.Sprintf("Unblocking domain failed:\n%s", err)
}
}

func BlockMime(db *sql.DB, mime string) bool {
func ListDomains(db *sql.DB) ([]string, error) {
query, err := db.Query("SELECT name, count FROM domains ORDER BY count DESC")
if err != nil {
return nil, err
}
var rows []string
for query.Next() {
var row domain
_ = query.Scan(&row.name, &row.count)
rows = append(rows, fmt.Sprintf("- %s (%d)", row.name, row.count))
}
return rows, nil
}

func findDomain(db *sql.DB, domain string) (bool, int) {
query := db.QueryRow("SELECT count FROM domains WHERE name = ?", domain)
var count int
err := query.Scan(&count)
return err == nil, count
}

func BlockMime(db *sql.DB, mime string) (bool, string) {
found, _ := findMime(db, mime)
if found {
return false, "MIME type already blocked"
}
_, err := db.Exec("INSERT INTO mimetypes (name, count) values (?, 0)", mime)
return err == nil
return err == nil, ""
}

func UnblockMime(db *sql.DB, mime string) bool {
func UnblockMime(db *sql.DB, mime string) (bool, string) {
found, _ := findMime(db, mime)
if !found {
return false, "MIME type not blocked"
}
_, err := db.Exec("DELETE FROM mimetypes WHERE name = ?", mime)
return err == nil
if err == nil {
return true, ""
} else {
return false, fmt.Sprintf("Unblocking MIME type failed:\n%s", err)
}
}

func ListMimes(db *sql.DB) ([]string, error) {
Expand All @@ -95,6 +141,13 @@ func ListMimes(db *sql.DB) ([]string, error) {
return rows, nil
}

func findMime(db *sql.DB, mime string) (bool, int) {
query := db.QueryRow("SELECT count FROM mimetypes WHERE name = ?", mime)
var count int
err := query.Scan(&count)
return err == nil, count
}

func createAllTables(db *sql.DB) {
for _, tab := range tables {
createTable(db, tab.name, tab.values)
Expand Down
3 changes: 2 additions & 1 deletion help.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ const urlHelp = "🛡️ <b>Guardian Help Page [url]</b> 🛡️:<br/>" +
"<code>!gd url <<args>></code><br/><br/>" +
"<b>Arguments</b>:<br/>" +
"<code>block <<domain>></code>: <i>Block domain in messages</i><br/>" +
"<code>unblock <<domain>></code>: <i>Unblock domain in messages</i>"
"<code>unblock <<domain>></code>: <i>Unblock domain in messages</i><br/>" +
"<code>list</code>: <i>List blocked domains in messages</i>"

const mimeHelp = "🛡️ <b>Guardian Help Page [mime]</b> 🛡️:<br/>" +
"<code>!gd mime <<args>></code><br/><br/>" +
Expand Down
48 changes: 44 additions & 4 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,12 +123,40 @@ func onManagementMessage(client *mautrix.Client, ctx context.Context, evt *event
switch subcommands[0] {
case "block":
if len(subcommands) == 2 {
db.BlockDomain(database, subcommands[1])
success, response := db.BlockDomain(database, subcommands[1])
if success {
_, _ = client.SendReaction(ctx, evt.RoomID, evt.ID, "✅")
} else {
_, _ = client.SendReaction(ctx, evt.RoomID, evt.ID, "❌")
_, _ = client.SendNotice(ctx, evt.RoomID, response)
}
return
}
case "unblock":
if len(subcommands) == 2 {
db.UnblockDomain(database, subcommands[1])
success, response := db.UnblockDomain(database, subcommands[1])
if success {
_, _ = client.SendReaction(ctx, evt.RoomID, evt.ID, "✅")
} else {
_, _ = client.SendReaction(ctx, evt.RoomID, evt.ID, "❌")
_, _ = client.SendNotice(ctx, evt.RoomID, response)
}
return
}
case "list":
if len(subcommands) == 1 {
list, err := db.ListDomains(database)
if err != nil {
return
}
message := fmt.Sprintf(
"Configured domains to block:\n%s",
strings.Join(list, "\n"),
)
_, err = client.SendNotice(ctx, config.mngtRoomId, message)
if err != nil {
return
}
return
}
}
Expand All @@ -139,12 +167,24 @@ func onManagementMessage(client *mautrix.Client, ctx context.Context, evt *event
switch subcommands[0] {
case "block":
if len(subcommands) == 2 {
db.BlockMime(database, subcommands[1])
success, response := db.BlockMime(database, subcommands[1])
if success {
_, _ = client.SendReaction(ctx, evt.RoomID, evt.ID, "✅")
} else {
_, _ = client.SendReaction(ctx, evt.RoomID, evt.ID, "❌")
_, _ = client.SendNotice(ctx, evt.RoomID, response)
}
return
}
case "unblock":
if len(subcommands) == 2 {
db.UnblockMime(database, subcommands[1])
success, response := db.UnblockMime(database, subcommands[1])
if success {
_, _ = client.SendReaction(ctx, evt.RoomID, evt.ID, "✅")
} else {
_, _ = client.SendReaction(ctx, evt.RoomID, evt.ID, "❌")
_, _ = client.SendNotice(ctx, evt.RoomID, response)
}
return
}
case "list":
Expand Down

0 comments on commit c631b73

Please sign in to comment.