Skip to content

Commit

Permalink
Per review
Browse files Browse the repository at this point in the history
  • Loading branch information
InSyncWithFoo committed Jan 9, 2025
1 parent 41f6284 commit b5d8b8f
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,7 @@ def _(a: type[int], b: type[Any]):

# The expression constructing the type is not taken into account
def _(a: type[int]):
# TODO: Infer the second argument as a type expression
assert_type(a, Type[int]) # error: [type-assertion-failure]
assert_type(a, Type[int]) # fine
```

## Gradual types
Expand All @@ -69,14 +68,12 @@ def _(a: Unknown, b: Any):
def _(a: type[Unknown], b: type[Any]):
# TODO: Should be `type[Unknown]`
reveal_type(a) # revealed: @Todo(unsupported type[X] special form)
reveal_type(b) # revealed: type[Any]
# TODO: Should be fine
assert_type(a, type[Any]) # error: [type-assertion-failure]

# TODO: Infer the second argument as a type expression
# Should be fine
assert_type(a, type[Unknown]) # error: [type-assertion-failure]
# TODO: Infer the second argument as a type expression
# Should be fine
assert_type(b, type[Any]) # error: [type-assertion-failure]
reveal_type(b) # revealed: type[Any]
# TODO: Should be fine
assert_type(b, type[Unknown]) # error: [type-assertion-failure]
```

## Tuples
Expand All @@ -86,19 +83,21 @@ Tuple types with the same elements are the same.
```py
from typing_extensions import assert_type

from knot_extensions import Unknown

def _(a: tuple[int, str, bytes]):
# TODO: Infer the second argument as a type expression
# Should be fine
assert_type(a, tuple[int, str, bytes]) # error: [type-assertion-failure]
assert_type(a, tuple[int, str, bytes]) # fine

assert_type(a, tuple[int, str]) # error: [type-assertion-failure]
assert_type(a, tuple[int, str, bytes, None]) # error: [type-assertion-failure]
assert_type(a, tuple[int, bytes, str]) # error: [type-assertion-failure]

def _(a: tuple[Any, ...]):
# TODO: Infer the second argument as a type expression
# Should be fine
assert_type(a, tuple[Any, ...]) # error: [type-assertion-failure]
def _(a: tuple[Any, ...], b: tuple[Unknown, ...]):
assert_type(a, tuple[Any, ...]) # fine
assert_type(a, tuple[Unknown, ...]) # fine

assert_type(b, tuple[Any, ...]) # fine
assert_type(b, tuple[Unknown, ...]) # fine
```

## Unions
Expand Down
39 changes: 25 additions & 14 deletions crates/red_knot_python_semantic/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2005,6 +2005,14 @@ impl<'db> Type<'db> {
CallOutcome::callable(binding)
}

Some(KnownFunction::AssertType) => {
let Some((_, asserted_ty)) = binding.two_parameter_tys() else {
return CallOutcome::callable(binding);
};

CallOutcome::asserted(binding, asserted_ty)
}

_ => CallOutcome::callable(binding),
}
}
Expand Down Expand Up @@ -3401,17 +3409,23 @@ impl KnownFunction {
/// Whether or not a particular function takes type expression as arguments, i.e. should
/// the argument of a call like `f(int)` be interpreted as the type int (true) or as the
/// type of the expression `int`, i.e. `Literal[int]` (false).
const fn takes_type_expression_arguments(self) -> bool {
matches!(
self,
KnownFunction::IsEquivalentTo
| KnownFunction::IsSubtypeOf
| KnownFunction::IsAssignableTo
| KnownFunction::IsDisjointFrom
| KnownFunction::IsFullyStatic
| KnownFunction::IsSingleton
| KnownFunction::IsSingleValued
)
const fn takes_type_expression_arguments(self) -> u32 {
const ALL_VALUES: u32 = 0b0;
const SINGLE_TYPE: u32 = 0b1;
const TYPE_TYPE: u32 = 0b11;
const VALUE_TYPE: u32 = 0b10;

match self {
KnownFunction::IsEquivalentTo => TYPE_TYPE,
KnownFunction::IsSubtypeOf => TYPE_TYPE,
KnownFunction::IsAssignableTo => TYPE_TYPE,
KnownFunction::IsDisjointFrom => TYPE_TYPE,
KnownFunction::IsFullyStatic => SINGLE_TYPE,
KnownFunction::IsSingleton => SINGLE_TYPE,
KnownFunction::IsSingleValued => SINGLE_TYPE,
KnownFunction::AssertType => VALUE_TYPE,
_ => ALL_VALUES,
}
}
}

Expand Down Expand Up @@ -4712,7 +4726,6 @@ pub(crate) mod tests {
#[test_case(Ty::BooleanLiteral(false), Ty::BooleanLiteral(false))]
#[test_case(Ty::SliceLiteral(0, 1, 2), Ty::SliceLiteral(0, 1, 2))]
#[test_case(Ty::BuiltinClassLiteral("str"), Ty::BuiltinClassLiteral("str"))]
#[test_case(Ty::SubclassOfAny, Ty::SubclassOfUnknown)]
#[test_case(Ty::BuiltinInstance("type"), Ty::SubclassOfBuiltinClass("object"))]
// TODO: Compare unions/intersections with different orders
// #[test_case(
Expand Down Expand Up @@ -4749,9 +4762,7 @@ pub(crate) mod tests {
}

#[test_case(Ty::BuiltinInstance("type"), Ty::SubclassOfAny)]
#[test_case(Ty::BuiltinInstance("type"), Ty::SubclassOfUnknown)]
#[test_case(Ty::SubclassOfBuiltinClass("object"), Ty::SubclassOfAny)]
#[test_case(Ty::SubclassOfBuiltinClass("object"), Ty::SubclassOfUnknown)]
#[test_case(
Ty::Union(vec![Ty::BuiltinInstance("str"), Ty::BuiltinInstance("int")]),
Ty::Union(vec![Ty::BuiltinInstance("int"), Ty::BuiltinInstance("str"), Ty::BuiltinInstance("bytes")])
Expand Down
25 changes: 15 additions & 10 deletions crates/red_knot_python_semantic/src/types/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -932,7 +932,7 @@ impl<'db> TypeInferenceBuilder<'db> {
self.infer_type_parameters(type_params);

if let Some(arguments) = class.arguments.as_deref() {
self.infer_arguments(arguments, false);
self.infer_arguments(arguments, 0b0);
}
}

Expand Down Expand Up @@ -2539,17 +2539,21 @@ impl<'db> TypeInferenceBuilder<'db> {
fn infer_arguments<'a>(
&mut self,
arguments: &'a ast::Arguments,
infer_as_type_expressions: bool,
infer_as_type_expressions: u32,
) -> CallArguments<'a, 'db> {
let infer_argument_type = if infer_as_type_expressions {
Self::infer_type_expression
} else {
Self::infer_expression
};

arguments
.arguments_source_order()
.map(|arg_or_keyword| {
.enumerate()
.map(|(index, arg_or_keyword)| {
// TODO: Remove this once we have proper overload matching
let infer_argument_type = if index < u32::BITS as usize
&& infer_as_type_expressions & (1 << index) != 0
{
Self::infer_type_expression
} else {
Self::infer_expression
};

match arg_or_keyword {
ast::ArgOrKeyword::Arg(arg) => match arg {
ast::Expr::Starred(ast::ExprStarred {
Expand Down Expand Up @@ -3095,7 +3099,8 @@ impl<'db> TypeInferenceBuilder<'db> {
let infer_arguments_as_type_expressions = function_type
.into_function_literal()
.and_then(|f| f.known(self.db()))
.is_some_and(KnownFunction::takes_type_expression_arguments);
.map(KnownFunction::takes_type_expression_arguments)
.unwrap_or(0b0);

let call_arguments = self.infer_arguments(arguments, infer_arguments_as_type_expressions);
function_type
Expand Down

0 comments on commit b5d8b8f

Please sign in to comment.