Skip to content

Commit

Permalink
🐛 fix contains on dict type (#5138)
Browse files Browse the repository at this point in the history
* 🐛 fix `contains` on `dict` type

We are experiencing issues where two `dict` types being compared don't
come up with the right results. Here is an example I ran to produce
this:

```coffee
> region = terraform.plan.variables.where( name == "gcp_region" ).first.value;
  terraform.plan.variables.where(name == "service_vpc_connector").all(v: v.value.contains(region) )
[failed] [].all()
  actual:   [
    0: terraform.plan.variable value="projects/prj-acc7/locations/us-central1/connectors/preprod-gen-central1-01" name="service_vpc_connector"
  ]
```

As you can see above we are looking at the `value` of the varaible,
which clearly is set to the region but doesn't come up. If we print out
the variables individually, it all comes out correctly.

The problem is that the string comparison doesn't work when we work with
`dict` types because in that example we have to create code that has to
with with arbitrary content and work with filtering on another `dict`.

This PR fixes the issue.

```coffee
> region = terraform.plan.variables.where( name == "gcp_region" ).first.value;
  terraform.plan.variables.where(name == "service_vpc_connector").all(v: v.value.contains(region) )
[ok] value: true
```

* 🐛 add missing string.contains(dict)

Signed-off-by: Dominik Richter <[email protected]>

* 🐛 add missing string.contains(array(dict))

Signed-off-by: Dominik Richter <[email protected]>

---------

Signed-off-by: Dominik Richter <[email protected]>
  • Loading branch information
arlimus authored Jan 31, 2025
1 parent 4c1b4a0 commit dec5f92
Show file tree
Hide file tree
Showing 6 changed files with 200 additions and 94 deletions.
14 changes: 9 additions & 5 deletions llx/builtin.go
Original file line number Diff line number Diff line change
Expand Up @@ -319,10 +319,12 @@ func init() {
string("||" + types.MapLike): {f: stringOrMapV2, Label: "&&"},
string("+" + types.String): {f: stringPlusStringV2, Label: "+"},
// fields
string("contains" + types.String): {f: stringContainsStringV2, Label: "contains"},
string("contains" + types.Array(types.String)): {f: stringContainsArrayStringV2, Label: "contains"},
string("contains" + types.Int): {f: stringContainsIntV2, Label: "contains"},
string("contains" + types.Array(types.Int)): {f: stringContainsArrayIntV2, Label: "contains"},
string("contains" + types.String): {f: stringContainsString, Label: "contains"},
string("contains" + types.Array(types.String)): {f: stringContainsArrayString, Label: "contains"},
string("contains" + types.Dict): {f: stringContainsDict, Label: "contains"},
string("contains" + types.Array(types.Dict)): {f: stringContainsArrayDict, Label: "contains"},
string("contains" + types.Int): {f: stringContainsInt, Label: "contains"},
string("contains" + types.Array(types.Int)): {f: stringContainsArrayInt, Label: "contains"},
string("contains" + types.Regex): {f: stringContainsRegex, Label: "contains"},
string("contains" + types.Array(types.Regex)): {f: stringContainsArrayRegex, Label: "contains"},
string("in"): {f: stringInArray, Label: "in"},
Expand All @@ -349,7 +351,7 @@ func init() {
// string("!=" + types.Int): {f: stringNotIntV2, Label: "!="},
// string("==" + types.Float): {f: stringCmpFloatV2, Label: "=="},
// string("!=" + types.Float): {f: stringNotFloatV2, Label: "!="},
// string("==" + types.Dict): {f: stringCmpDictV2, Label: "=="},
string("==" + types.Dict): {f: stringsliceEqDict, Label: "=="},
// string("!=" + types.Dict): {f: stringNotDictV2, Label: "!="},
string("==" + types.Array(types.String)): {f: stringsliceEqArrayString, Label: "=="},
},
Expand Down Expand Up @@ -551,6 +553,8 @@ func init() {
"containsNone": {f: dictContainsNone},
string("contains" + types.String): {f: dictContainsStringV2, Label: "contains"},
string("contains" + types.Array(types.String)): {f: dictContainsArrayStringV2, Label: "contains"},
string("contains" + types.Dict): {f: dictContainsDict, Label: "contains"},
string("contains" + types.Array(types.Dict)): {f: dictContainsArrayDict, Label: "contains"},
string("contains" + types.Int): {f: dictContainsIntV2, Label: "contains"},
string("contains" + types.Array(types.Int)): {f: dictContainsArrayIntV2, Label: "contains"},
string("contains" + types.Regex): {f: dictContainsRegex, Label: "contains"},
Expand Down
55 changes: 53 additions & 2 deletions llx/builtin_map.go
Original file line number Diff line number Diff line change
Expand Up @@ -1373,6 +1373,30 @@ func anyContainsString(an interface{}, s string) bool {
}
}

func anyContainsAny(an any, s any) (bool, error) {
if an == nil {
return false, nil
}

switch x := an.(type) {
case string:
return opStringContainsDict(x, s)
case []interface{}:
for i := range x {
ok, err := anyContainsAny(x[i], s)
if err != nil {
return false, err
}
if ok {
return true, nil
}
}
return false, nil
default:
return false, nil
}
}

func anyContainsRegex(an interface{}, re *regexp.Regexp) bool {
if an == nil {
return false
Expand Down Expand Up @@ -1408,6 +1432,24 @@ func dictContainsStringV2(e *blockExecutor, bind *RawData, chunk *Chunk, ref uin
return BoolData(ok), 0, nil
}

func dictContainsDict(e *blockExecutor, bind *RawData, chunk *Chunk, ref uint64) (*RawData, uint64, error) {
argRef := chunk.Function.Args[0]
arg, rref, err := e.resolveValue(argRef, ref)
if err != nil || rref > 0 {
return nil, rref, err
}

if arg.Value == nil {
return BoolFalse, 0, nil
}

ok, err := anyContainsAny(bind.Value, arg.Value)
if err != nil {
return BoolData(false), 0, err
}
return BoolData(ok), 0, nil
}

func dictContainsIntV2(e *blockExecutor, bind *RawData, chunk *Chunk, ref uint64) (*RawData, uint64, error) {
argRef := chunk.Function.Args[0]
arg, rref, err := e.resolveValue(argRef, ref)
Expand Down Expand Up @@ -1449,7 +1491,16 @@ func dictContainsRegex(e *blockExecutor, bind *RawData, chunk *Chunk, ref uint64
func dictContainsArrayStringV2(e *blockExecutor, bind *RawData, chunk *Chunk, ref uint64) (*RawData, uint64, error) {
switch bind.Value.(type) {
case string:
return stringContainsArrayStringV2(e, bind, chunk, ref)
return stringContainsArrayString(e, bind, chunk, ref)
default:
return nil, 0, errors.New("dict value does not support field `contains`")
}
}

func dictContainsArrayDict(e *blockExecutor, bind *RawData, chunk *Chunk, ref uint64) (*RawData, uint64, error) {
switch bind.Value.(type) {
case string:
return stringContainsArrayString(e, bind, chunk, ref)
default:
return nil, 0, errors.New("dict value does not support field `contains`")
}
Expand All @@ -1458,7 +1509,7 @@ func dictContainsArrayStringV2(e *blockExecutor, bind *RawData, chunk *Chunk, re
func dictContainsArrayIntV2(e *blockExecutor, bind *RawData, chunk *Chunk, ref uint64) (*RawData, uint64, error) {
switch bind.Value.(type) {
case string:
return stringContainsArrayIntV2(e, bind, chunk, ref)
return stringContainsArrayInt(e, bind, chunk, ref)
default:
return nil, 0, errors.New("dict value does not support field `contains`")
}
Expand Down
165 changes: 78 additions & 87 deletions llx/builtin_simple.go
Original file line number Diff line number Diff line change
Expand Up @@ -2078,47 +2078,49 @@ func mapOrTimeV2(e *blockExecutor, bind *RawData, chunk *Chunk, ref uint64) (*Ra

// string methods

func stringContainsStringV2(e *blockExecutor, bind *RawData, chunk *Chunk, ref uint64) (*RawData, uint64, error) {
if bind.Value == nil {
return BoolFalse, 0, nil
}

argRef := chunk.Function.Args[0]
arg, rref, err := e.resolveValue(argRef, ref)
if err != nil || rref > 0 {
return nil, rref, err
func opStringContainsDict(left string, right any) (bool, error) {
switch x := right.(type) {
case string:
return strings.Contains(left, x), nil
case int64:
val := strconv.FormatInt(x, 10)
return strings.Contains(left, val), nil
case float64:
val := strconv.FormatFloat(x, 'f', -1, 64)
return strings.Contains(left, val), nil
default:
return false, nil
}
}

if arg.Value == nil {
return BoolFalse, 0, nil
func opStringContainsString(left string, right any) (bool, error) {
v, ok := right.(string)
if !ok {
return false, nil
}

ok := strings.Contains(bind.Value.(string), arg.Value.(string))
return BoolData(ok), 0, nil
return strings.Contains(left, v), nil
}

func stringContainsIntV2(e *blockExecutor, bind *RawData, chunk *Chunk, ref uint64) (*RawData, uint64, error) {
if bind.Value == nil {
return BoolFalse, 0, nil
}
func opStringContainsInt(left string, right any) (bool, error) {
val := strconv.FormatInt(right.(int64), 10)
return strings.Contains(left, val), nil
}

argRef := chunk.Function.Args[0]
arg, rref, err := e.resolveValue(argRef, ref)
if err != nil || rref > 0 {
return nil, rref, err
func opStringContainsRegex(left string, right any) (bool, error) {
reContent, ok := right.(string)
if !ok {
return false, nil
}

if arg.Value == nil {
return BoolFalse, 0, nil
re, err := regexp.Compile(right.(string))
if err != nil {
return false, errors.New("Failed to compile regular expression: " + reContent)
}

val := strconv.FormatInt(arg.Value.(int64), 10)

ok := strings.Contains(bind.Value.(string), val)
return BoolData(ok), 0, nil
return re.MatchString(left), nil
}

func stringContainsRegex(e *blockExecutor, bind *RawData, chunk *Chunk, ref uint64) (*RawData, uint64, error) {
func stringContainsOther(e *blockExecutor, bind *RawData, chunk *Chunk, ref uint64, cmp func(left string, right any) (bool, error)) (*RawData, uint64, error) {
if bind.Value == nil {
return BoolFalse, 0, nil
}
Expand All @@ -2133,17 +2135,15 @@ func stringContainsRegex(e *blockExecutor, bind *RawData, chunk *Chunk, ref uint
return BoolFalse, 0, nil
}

reContent := arg.Value.(string)
re, err := regexp.Compile(reContent)
left := bind.Value.(string)
res, err := cmp(left, arg.Value)
if err != nil {
return nil, 0, errors.New("Failed to compile regular expression: " + reContent)
return nil, 0, err
}

ok := re.MatchString(bind.Value.(string))
return BoolData(ok), 0, nil
return BoolData(res), 0, nil
}

func stringContainsArrayStringV2(e *blockExecutor, bind *RawData, chunk *Chunk, ref uint64) (*RawData, uint64, error) {
func stringContainsArrayOther(e *blockExecutor, bind *RawData, chunk *Chunk, ref uint64, cmp func(left string, right any) (bool, error)) (*RawData, uint64, error) {
if bind.Value == nil {
return BoolFalse, 0, nil
}
Expand All @@ -2158,73 +2158,50 @@ func stringContainsArrayStringV2(e *blockExecutor, bind *RawData, chunk *Chunk,
return BoolFalse, 0, nil
}

arr := arg.Value.([]interface{})
arr := arg.Value.([]any)
for i := range arr {
v := arr[i].(string)
if strings.Contains(bind.Value.(string), v) {
found, err := cmp(bind.Value.(string), arr[i])
if err != nil {
return nil, 0, err
}
if found {
return BoolData(true), 0, nil
}
}

return BoolData(false), 0, nil
}

func stringContainsArrayIntV2(e *blockExecutor, bind *RawData, chunk *Chunk, ref uint64) (*RawData, uint64, error) {
if bind.Value == nil {
return BoolFalse, 0, nil
}

argRef := chunk.Function.Args[0]
arg, rref, err := e.resolveValue(argRef, ref)
if err != nil || rref > 0 {
return nil, rref, err
}

if arg.Value == nil {
return BoolFalse, 0, nil
}

arr := arg.Value.([]interface{})
for i := range arr {
v := arr[i].(int64)
val := strconv.FormatInt(v, 10)
if strings.Contains(bind.Value.(string), val) {
return BoolData(true), 0, nil
}
}
func stringContainsString(e *blockExecutor, bind *RawData, chunk *Chunk, ref uint64) (*RawData, uint64, error) {
return stringContainsOther(e, bind, chunk, ref, opStringContainsString)
}

return BoolData(false), 0, nil
func stringContainsDict(e *blockExecutor, bind *RawData, chunk *Chunk, ref uint64) (*RawData, uint64, error) {
return stringContainsOther(e, bind, chunk, ref, opStringContainsDict)
}

func stringContainsArrayRegex(e *blockExecutor, bind *RawData, chunk *Chunk, ref uint64) (*RawData, uint64, error) {
if bind.Value == nil {
return BoolFalse, 0, nil
}
func stringContainsInt(e *blockExecutor, bind *RawData, chunk *Chunk, ref uint64) (*RawData, uint64, error) {
return stringContainsOther(e, bind, chunk, ref, opStringContainsInt)
}

argRef := chunk.Function.Args[0]
arg, rref, err := e.resolveValue(argRef, ref)
if err != nil || rref > 0 {
return nil, rref, err
}
func stringContainsRegex(e *blockExecutor, bind *RawData, chunk *Chunk, ref uint64) (*RawData, uint64, error) {
return stringContainsOther(e, bind, chunk, ref, opStringContainsRegex)
}

if arg.Value == nil {
return BoolFalse, 0, nil
}
func stringContainsArrayString(e *blockExecutor, bind *RawData, chunk *Chunk, ref uint64) (*RawData, uint64, error) {
return stringContainsArrayOther(e, bind, chunk, ref, opStringContainsString)
}

arr := arg.Value.([]interface{})
for i := range arr {
v := arr[i].(string)
re, err := regexp.Compile(v)
if err != nil {
return nil, 0, errors.New("Failed to compile regular expression: " + v)
}
func stringContainsArrayDict(e *blockExecutor, bind *RawData, chunk *Chunk, ref uint64) (*RawData, uint64, error) {
return stringContainsArrayOther(e, bind, chunk, ref, opStringContainsDict)
}

if re.MatchString(bind.Value.(string)) {
return BoolTrue, 0, nil
}
}
func stringContainsArrayInt(e *blockExecutor, bind *RawData, chunk *Chunk, ref uint64) (*RawData, uint64, error) {
return stringContainsArrayOther(e, bind, chunk, ref, opStringContainsInt)
}

return BoolFalse, 0, nil
func stringContainsArrayRegex(e *blockExecutor, bind *RawData, chunk *Chunk, ref uint64) (*RawData, uint64, error) {
return stringContainsArrayOther(e, bind, chunk, ref, opStringContainsRegex)
}

func stringInArray(e *blockExecutor, bind *RawData, chunk *Chunk, ref uint64) (*RawData, uint64, error) {
Expand Down Expand Up @@ -2555,6 +2532,20 @@ func stringsliceEqString(e *blockExecutor, bind *RawData, chunk *Chunk, ref uint
})
}

func stringsliceEqDict(e *blockExecutor, bind *RawData, chunk *Chunk, ref uint64) (*RawData, uint64, error) {
return dataOpV2(e, bind, chunk, ref, types.Int, func(left interface{}, right interface{}) *RawData {
l := left.(string)
ok, err := opStringContainsDict(l, right)
if err != nil {
return &RawData{Error: err, Type: types.String}
}
if !ok {
return StringData("")
}
return StringData(l)
})
}

func stringsliceEqArrayString(e *blockExecutor, bind *RawData, chunk *Chunk, ref uint64) (*RawData, uint64, error) {
return dataOpV2(e, bind, chunk, ref, types.Int, func(left interface{}, right interface{}) *RawData {
l := left.(string)
Expand Down
30 changes: 30 additions & 0 deletions mql/mql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,36 @@ func TestDictMethods(t *testing.T) {
x.TestSimple(t, []testutils.SimpleTest{
{
Code: "muser.dict.nonexisting.contains('abc')",
ResultIndex: 3,
Expectation: false,
},
{
Code: "muser.dict.string.contains(muser.dict.string2)",
ResultIndex: 3,
Expectation: false,
},
{
Code: "muser.dict.string.contains(muser.dict.string)",
ResultIndex: 3,
Expectation: true,
},
{
Code: "'<< hello world >>'.contains(muser.dict.string)",
ResultIndex: 1,
Expectation: true,
},
{
Code: "'<< hello + world >>'.contains(muser.dict.string)",
ResultIndex: 1,
Expectation: false,
},
{
Code: "'<< hello world >>'.contains([muser.dict.string])",
ResultIndex: 1,
Expectation: true,
},
{
Code: "'<< hello + world >>'.contains([muser.dict.string])",
ResultIndex: 1,
Expectation: false,
},
Expand Down
Loading

0 comments on commit dec5f92

Please sign in to comment.