Skip to content

Commit

Permalink
Add type parameter to Core.Literal.unwrap() method
Browse files Browse the repository at this point in the history
The method is now `<T> T Literal.unwrap(Class<T>)`, so that
you can request the value type it returns.

We now use `unwrap` for accessing all literal values. It
deals with both wrapped values (which may not implement
`Comparable`) and unwrapped values (which must implement
`Comparable`).
  • Loading branch information
julianhyde committed Dec 13, 2023
1 parent 9ec6ca3 commit 12353b8
Show file tree
Hide file tree
Showing 8 changed files with 65 additions and 30 deletions.
45 changes: 39 additions & 6 deletions src/main/java/net/hydromatic/morel/ast/Core.java
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@
import com.google.common.collect.Iterables;
import com.google.common.collect.Ordering;

import java.math.BigDecimal;
import java.math.BigInteger;
import java.util.List;
import java.util.Objects;
import java.util.Set;
Expand Down Expand Up @@ -617,8 +619,38 @@ static Comparable wrap(Exp exp, Object value) {
return new Wrapper(exp, value);
}

public Object unwrap() {
return ((Wrapper) value).o;
/** Returns the value of this literal as a given class,
* or throws {@link ClassCastException}. If the class is not
* {@link Comparable}, the value will be in a wrapper. */
public <C> C unwrap(Class<C> clazz) {
Object v;
if (clazz.isInstance(value) && clazz != Object.class) {
v = value;
} else if (Number.class.isAssignableFrom(clazz)
&& value instanceof Number) {
Number number = (Number) value;
if (clazz == Double.class) {
v = number.doubleValue();
} else if (clazz == Float.class) {
v = number.floatValue();
} else if (clazz == Long.class) {
v = number.longValue();
} else if (clazz == Integer.class) {
v = number.intValue();
} else if (clazz == Short.class) {
v = number.shortValue();
} else if (clazz == Byte.class) {
v = number.byteValue();
} else if (clazz == BigInteger.class
&& number instanceof BigDecimal) {
v = ((BigDecimal) number).toBigIntegerExact();
} else {
v = value;
}
} else {
v = ((Wrapper) value).o;
}
return clazz.cast(v);
}

@Override public int hashCode() {
Expand Down Expand Up @@ -1157,8 +1189,8 @@ public static class Scan extends FromStep {
}

private boolean isLiteralTrue() {
return condition instanceof Literal
&& ((Literal) condition).value.equals(true);
return condition.op == Op.BOOL_LITERAL
&& ((Literal) condition).unwrap(Boolean.class);
}

public Scan copy(List<Binding> bindings, Pat pat, Exp exp, Exp condition) {
Expand Down Expand Up @@ -1370,7 +1402,7 @@ public List<Exp> args() {
@Override AstWriter unparse(AstWriter w, int left, int right) {
switch (fn.op) {
case FN_LITERAL:
final BuiltIn builtIn = (BuiltIn) ((Literal) fn).value;
final BuiltIn builtIn = ((Literal) fn).unwrap(BuiltIn.class);

// Because the Core language is narrower than AST, a few AST expression
// types do not exist in Core and are translated to function
Expand Down Expand Up @@ -1403,7 +1435,8 @@ public Apply copy(Exp fn, Exp arg) {
}

@Override public boolean isCallTo(BuiltIn builtIn) {
return fn.op == Op.FN_LITERAL && ((Literal) fn).value == builtIn;
return fn.op == Op.FN_LITERAL
&& ((Literal) fn).unwrap(BuiltIn.class) == builtIn;
}
}

Expand Down
3 changes: 1 addition & 2 deletions src/main/java/net/hydromatic/morel/ast/CoreBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -590,8 +590,7 @@ public Pair<Core.Exp, List<Core.Exp>> mergeExtents(TypeSystem typeSystem,
for (Core.Exp exp : exps) {
if (exp.isCallTo(BuiltIn.Z_EXTENT)) {
final Core.Literal argLiteral = (Core.Literal) ((Core.Apply) exp).arg;
final Core.Wrapper wrapper = (Core.Wrapper) argLiteral.value;
final RangeExtent list = wrapper.unwrap(RangeExtent.class);
final RangeExtent list = argLiteral.unwrap(RangeExtent.class);
rangeSet = intersect
? rangeSet.intersection(list.rangeSet)
: rangeSet.union(list.rangeSet);
Expand Down
2 changes: 1 addition & 1 deletion src/main/java/net/hydromatic/morel/ast/FromBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ public FromBuilder addAll(Iterable<? extends Core.FromStep> steps) {

public FromBuilder where(Core.Exp condition) {
if (condition.op == Op.BOOL_LITERAL
&& (Boolean) ((Core.Literal) condition).value) {
&& ((Core.Literal) condition).unwrap(Boolean.class)) {
// skip "where true"
return this;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,8 @@ Code toRel4(Environment env, Code code, Type type) {
break;

case FN_LITERAL:
final BuiltIn builtIn = (BuiltIn) ((Core.Literal) apply.fn).value;
final BuiltIn builtIn =
((Core.Literal) apply.fn).unwrap(BuiltIn.class);
switch (builtIn) {
case Z_LIST:
final List<Core.Exp> args = apply.args();
Expand Down Expand Up @@ -497,7 +498,7 @@ record = toRecord(cx, id);
final Core.Apply apply = (Core.Apply) exp;
switch (apply.fn.op) {
case FN_LITERAL:
BuiltIn op = (BuiltIn) ((Core.Literal) apply.fn).value;
BuiltIn op = ((Core.Literal) apply.fn).unwrap(BuiltIn.class);

// Is it a unary operator with a Calcite equivalent? E.g. not => NOT
final SqlOperator unaryOp = UNARY_OPERATORS.get(op);
Expand Down Expand Up @@ -758,7 +759,7 @@ private RelContext group(RelContext cx, Core.Group group) {
* support aggregate functions defined by expressions (e.g. lambdas). */
@Nonnull private SqlAggFunction aggOp(Core.Exp aggregate) {
if (aggregate instanceof Core.Literal) {
switch ((BuiltIn) ((Core.Literal) aggregate).value) {
switch (((Core.Literal) aggregate).unwrap(BuiltIn.class)) {
case RELATIONAL_SUM:
case Z_SUM_INT:
case Z_SUM_REAL:
Expand Down
23 changes: 11 additions & 12 deletions src/main/java/net/hydromatic/morel/compile/Compiler.java
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@
import org.apache.calcite.util.Util;
import org.checkerframework.checker.nullness.qual.Nullable;

import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
Expand Down Expand Up @@ -187,39 +186,39 @@ public Code compile(Context cx, Core.Exp expression) {
switch (expression.op) {
case BOOL_LITERAL:
literal = (Core.Literal) expression;
final Boolean boolValue = (Boolean) literal.value;
final Boolean boolValue = literal.unwrap(Boolean.class);
return Codes.constant(boolValue);

case CHAR_LITERAL:
literal = (Core.Literal) expression;
final Character charValue = (Character) literal.value;
final Character charValue = literal.unwrap(Character.class);
return Codes.constant(charValue);

case INT_LITERAL:
literal = (Core.Literal) expression;
return Codes.constant(((BigDecimal) literal.value).intValue());
return Codes.constant(literal.unwrap(Integer.class));

case REAL_LITERAL:
literal = (Core.Literal) expression;
return Codes.constant(((Number) literal.value).floatValue());
return Codes.constant(literal.unwrap(Float.class));

case STRING_LITERAL:
literal = (Core.Literal) expression;
final String stringValue = (String) literal.value;
final String stringValue = literal.unwrap(String.class);
return Codes.constant(stringValue);

case UNIT_LITERAL:
return Codes.constant(Unit.INSTANCE);

case FN_LITERAL:
literal = (Core.Literal) expression;
final BuiltIn builtIn = (BuiltIn) literal.value;
final BuiltIn builtIn = literal.unwrap(BuiltIn.class);
return Codes.constant(Codes.BUILT_IN_VALUES.get(builtIn));

case INTERNAL_LITERAL:
case VALUE_LITERAL:
literal = (Core.Literal) expression;
return Codes.constant(literal.unwrap());
return Codes.constant(literal.unwrap(Object.class));

case LET:
return compileLet(cx, (Core.Let) expression);
Expand Down Expand Up @@ -282,7 +281,7 @@ protected Code compileApply(Context cx, Core.Apply apply) {
// Is this is a call to a built-in operator?
switch (apply.fn.op) {
case FN_LITERAL:
final BuiltIn builtIn = (BuiltIn) ((Core.Literal) apply.fn).value;
final BuiltIn builtIn = ((Core.Literal) apply.fn).unwrap(BuiltIn.class);
return compileCall(cx, builtIn, apply.arg, apply.pos);
}
final Code argCode = compileArg(cx, apply.arg);
Expand Down Expand Up @@ -482,13 +481,13 @@ private Applicable compileApplicable(Context cx, Core.Exp fn, Type argType,
Pos pos) {
switch (fn.op) {
case FN_LITERAL:
final BuiltIn builtIn = (BuiltIn) ((Core.Literal) fn).value;
final BuiltIn builtIn = ((Core.Literal) fn).unwrap(BuiltIn.class);
final Object o = Codes.BUILT_IN_VALUES.get(builtIn);
return toApplicable(cx, o, argType, pos);

case VALUE_LITERAL:
final Core.Literal literal = (Core.Literal) fn;
return toApplicable(cx, literal.unwrap(), argType, pos);
return toApplicable(cx, literal.unwrap(Object.class), argType, pos);

case ID:
final Binding binding = cx.env.getOpt(((Core.Id) fn).idPat);
Expand Down Expand Up @@ -523,7 +522,7 @@ private Applicable compileApplicable(Context cx, Core.Exp fn, Type argType,
switch (exp.op) {
case FN_LITERAL:
final Core.Literal literal = (Core.Literal) exp;
final BuiltIn builtIn = (BuiltIn) literal.value;
final BuiltIn builtIn = literal.unwrap(BuiltIn.class);
return (Applicable) Codes.BUILT_IN_VALUES.get(builtIn);
}
final Code code = compile(cx, exp);
Expand Down
2 changes: 1 addition & 1 deletion src/main/java/net/hydromatic/morel/compile/Extents.java
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ void g3(Multimap<Core.Pat, Core.Exp> map, Core.Exp exp) {
apply = (Core.Apply) exp;
switch (apply.fn.op) {
case FN_LITERAL:
BuiltIn builtIn = (BuiltIn) ((Core.Literal) apply.fn).value;
BuiltIn builtIn = ((Core.Literal) apply.fn).unwrap(BuiltIn.class);
switch (builtIn) {
case Z_ANDALSO:
// Expression is 'andalso'. Visit each pattern, and union the
Expand Down
2 changes: 1 addition & 1 deletion src/main/java/net/hydromatic/morel/compile/Inliner.java
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ public static Inliner of(TypeSystem typeSystem, Environment env,
if (apply2.fn.op == Op.RECORD_SELECTOR
&& apply2.arg.op == Op.VALUE_LITERAL) {
final Core.RecordSelector selector = (Core.RecordSelector) apply2.fn;
final List list = (List) ((Core.Literal) apply2.arg).unwrap();
final List list = ((Core.Literal) apply2.arg).unwrap(List.class);
final Object o = list.get(selector.slot);
if (o instanceof Applicable || o instanceof Macro) {
// E.g. apply is '#filter List', o is Codes.LIST_FILTER,
Expand Down
11 changes: 7 additions & 4 deletions src/test/java/net/hydromatic/morel/compile/ExtentTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@

import static org.hamcrest.CoreMatchers.instanceOf;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.hasToString;
import static org.hamcrest.core.Is.is;

/**
Expand Down Expand Up @@ -137,14 +138,16 @@ Core.Literal intLiteral(int i) {
Core.Exp x = generator(f.typeSystem, xPat, exp);
assertThat(x, instanceOf(Core.Apply.class));
assertThat(((Core.Apply) x).fn, instanceOf(Core.Literal.class));
assertThat(((Core.Literal) ((Core.Apply) x).fn).value, is(BuiltIn.Z_EXTENT));
assertThat(x.toString(), is("extent \"int [[3..5), (5..10)]\""));
assertThat(((Core.Literal) ((Core.Apply) x).fn).unwrap(BuiltIn.class),
is(BuiltIn.Z_EXTENT));
assertThat(x, hasToString("extent \"int [[3..5), (5..10)]\""));

Core.Exp y = generator(f.typeSystem, yPat, exp);
assertThat(y, instanceOf(Core.Apply.class));
assertThat(((Core.Apply) y).fn, instanceOf(Core.Literal.class));
assertThat(((Core.Literal) ((Core.Apply) y).fn).value, is(BuiltIn.Z_LIST));
assertThat(y.toString(), is("[20]"));
assertThat(((Core.Literal) ((Core.Apply) y).fn).unwrap(BuiltIn.class),
is(BuiltIn.Z_LIST));
assertThat(y, hasToString("[20]"));
}

}
Expand Down

0 comments on commit 12353b8

Please sign in to comment.