diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index b4a2ec68afda0..7dd3b257474f3 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -6007,24 +6007,13 @@ mod tests { use crate::semantic_index::definition::Definition; use crate::semantic_index::symbol::FileScopeId; use crate::semantic_index::{global_scope, semantic_index, symbol_table, use_def_map}; + use crate::types::check_types; use ruff_db::files::{system_path_to_file, File}; use ruff_db::system::DbWithTestSystem; use ruff_db::testing::assert_function_query_was_not_run; use super::*; - #[track_caller] - fn assert_public_type(db: &TestDb, file_name: &str, symbol_name: &str, expected: &str) { - let file = system_path_to_file(db, file_name).expect("file to exist"); - - let ty = global_symbol(db, file, symbol_name).expect_type(); - assert_eq!( - ty.display(db).to_string(), - expected, - "Mismatch for symbol '{symbol_name}' in '{file_name}'" - ); - } - #[track_caller] fn get_symbol<'db>( db: &'db TestDb, @@ -6049,26 +6038,41 @@ mod tests { symbol(db, scope, symbol_name) } + #[track_caller] + fn assert_diagnostic_messages(diagnostics: &TypeCheckDiagnostics, expected: &[&str]) { + let messages: Vec<&str> = diagnostics + .iter() + .map(|diagnostic| diagnostic.message()) + .collect(); + assert_eq!(&messages, expected); + } + + #[track_caller] + fn assert_file_diagnostics(db: &TestDb, filename: &str, expected: &[&str]) { + let file = system_path_to_file(db, filename).unwrap(); + let diagnostics = check_types(db, file); + + assert_diagnostic_messages(diagnostics, expected); + } + #[test] fn not_literal_string() -> anyhow::Result<()> { let mut db = setup_db(); let content = format!( r#" - v = not "{y}" - w = not 10*"{y}" - x = not "{y}"*10 - z = not 0*"{y}" - u = not (-100)*"{y}" - "#, + from typing_extensions import assert_type + + assert_type(not "{y}", bool) + assert_type(not 10*"{y}", bool) + assert_type(not "{y}"*10, bool) + assert_type(not 0*"{y}", Literal[True]) + assert_type(not (-100)*"{y}", Literal[True]) + "#, y = "a".repeat(TypeInferenceBuilder::MAX_STRING_LITERAL_SIZE + 1), ); db.write_dedented("src/a.py", &content)?; - assert_public_type(&db, "src/a.py", "v", "bool"); - assert_public_type(&db, "src/a.py", "w", "bool"); - assert_public_type(&db, "src/a.py", "x", "bool"); - assert_public_type(&db, "src/a.py", "z", "Literal[True]"); - assert_public_type(&db, "src/a.py", "u", "Literal[True]"); + assert_file_diagnostics(&db, "src/a.py", &[]); Ok(()) } @@ -6076,37 +6080,24 @@ mod tests { #[test] fn multiplied_string() -> anyhow::Result<()> { let mut db = setup_db(); - - db.write_dedented( - "src/a.py", - &format!( - r#" - w = 2 * "hello" - x = "goodbye" * 3 - y = "a" * {y} - z = {z} * "b" - a = 0 * "hello" - b = -3 * "hello" + let content = format!( + r#" + from typing_extensions import assert_type + + assert_type(2 * "hello", Literal["hellohello"]) + assert_type("goodbye" * 3, Literal["goodbyegoodbyegoodbye"]) + assert_type("a" * {y}, Literal["{a_repeated}"]) + assert_type({z} * "b", LiteralString) + assert_type(0 * "hello", Literal[""]) + assert_type(-3 * "hello", Literal[""]) "#, - y = TypeInferenceBuilder::MAX_STRING_LITERAL_SIZE, - z = TypeInferenceBuilder::MAX_STRING_LITERAL_SIZE + 1 - ), - )?; - - assert_public_type(&db, "src/a.py", "w", r#"Literal["hellohello"]"#); - assert_public_type(&db, "src/a.py", "x", r#"Literal["goodbyegoodbyegoodbye"]"#); - assert_public_type( - &db, - "src/a.py", - "y", - &format!( - r#"Literal["{}"]"#, - "a".repeat(TypeInferenceBuilder::MAX_STRING_LITERAL_SIZE) - ), + y = TypeInferenceBuilder::MAX_STRING_LITERAL_SIZE, + z = TypeInferenceBuilder::MAX_STRING_LITERAL_SIZE + 1, + a_repeated = "a".repeat(TypeInferenceBuilder::MAX_STRING_LITERAL_SIZE), ); - assert_public_type(&db, "src/a.py", "z", "LiteralString"); - assert_public_type(&db, "src/a.py", "a", r#"Literal[""]"#); - assert_public_type(&db, "src/a.py", "b", r#"Literal[""]"#); + db.write_dedented("src/a.py", &content)?; + + assert_file_diagnostics(&db, "src/a.py", &[]); Ok(()) } @@ -6116,21 +6107,20 @@ mod tests { let mut db = setup_db(); let content = format!( r#" - v = "{y}" - w = 10*"{y}" - x = "{y}"*10 - z = 0*"{y}" - u = (-100)*"{y}" - "#, + from typing_extensions import assert_type + + assert_type("{y}", LiteralString) + assert_type(10*"{y}", LiteralString) + assert_type("{y}"*10, LiteralString) + assert_type(0*"{y}", Literal[""]) + assert_type((-100)*"{y}", Literal[""]) + "#, y = "a".repeat(TypeInferenceBuilder::MAX_STRING_LITERAL_SIZE + 1), ); db.write_dedented("src/a.py", &content)?; - assert_public_type(&db, "src/a.py", "v", "LiteralString"); - assert_public_type(&db, "src/a.py", "w", "LiteralString"); - assert_public_type(&db, "src/a.py", "x", "LiteralString"); - assert_public_type(&db, "src/a.py", "z", r#"Literal[""]"#); - assert_public_type(&db, "src/a.py", "u", r#"Literal[""]"#); + assert_file_diagnostics(&db, "src/a.py", &[]); + Ok(()) } @@ -6139,16 +6129,17 @@ mod tests { let mut db = setup_db(); let content = format!( r#" - w = "{y}" - x = "a" + "{z}" - "#, + from typing_extensions import assert_type + + assert_type("{y}", LiteralString) + assert_type("a" + "{z}", LiteralString) + "#, y = "a".repeat(TypeInferenceBuilder::MAX_STRING_LITERAL_SIZE + 1), z = "a".repeat(TypeInferenceBuilder::MAX_STRING_LITERAL_SIZE), ); db.write_dedented("src/a.py", &content)?; - assert_public_type(&db, "src/a.py", "w", "LiteralString"); - assert_public_type(&db, "src/a.py", "x", "LiteralString"); + assert_file_diagnostics(&db, "src/a.py", &[]); Ok(()) } @@ -6158,19 +6149,18 @@ mod tests { let mut db = setup_db(); let content = format!( r#" - v = "{y}" - w = "{y}" + "a" - x = "a" + "{y}" - z = "{y}" + "{y}" - "#, + from typing_extensions import assert_type + + assert_type("{y}", LiteralString) + assert_type("{y}" + "a", LiteralString) + assert_type("a" + "{y}", LiteralString) + assert_type("{y}" + "{y}", LiteralString) + "#, y = "a".repeat(TypeInferenceBuilder::MAX_STRING_LITERAL_SIZE + 1), ); db.write_dedented("src/a.py", &content)?; - assert_public_type(&db, "src/a.py", "v", "LiteralString"); - assert_public_type(&db, "src/a.py", "w", "LiteralString"); - assert_public_type(&db, "src/a.py", "x", "LiteralString"); - assert_public_type(&db, "src/a.py", "z", "LiteralString"); + assert_file_diagnostics(&db, "src/a.py", &[]); Ok(()) } @@ -6257,22 +6247,22 @@ mod tests { db.write_files([ ("/src/a.py", "from foo import x"), - ("/src/foo.py", "x = 10\ndef foo(): ..."), + ("/src/foo.py", "x: int = 10\ndef foo(): ..."), ])?; let a = system_path_to_file(&db, "/src/a.py").unwrap(); let x_ty = global_symbol(&db, a, "x").expect_type(); - assert_eq!(x_ty.display(&db).to_string(), "Literal[10]"); + assert_eq!(x_ty.display(&db).to_string(), "int"); // Change `x` to a different value - db.write_file("/src/foo.py", "x = 20\ndef foo(): ...")?; + db.write_file("/src/foo.py", "x: bool = True\ndef foo(): ...")?; let a = system_path_to_file(&db, "/src/a.py").unwrap(); let x_ty_2 = global_symbol(&db, a, "x").expect_type(); - assert_eq!(x_ty_2.display(&db).to_string(), "Literal[20]"); + assert_eq!(x_ty_2.display(&db).to_string(), "bool"); Ok(()) } @@ -6283,15 +6273,15 @@ mod tests { db.write_files([ ("/src/a.py", "from foo import x"), - ("/src/foo.py", "x = 10\ndef foo(): y = 1"), + ("/src/foo.py", "x: int = 10\ndef foo(): y = 1"), ])?; let a = system_path_to_file(&db, "/src/a.py").unwrap(); let x_ty = global_symbol(&db, a, "x").expect_type(); - assert_eq!(x_ty.display(&db).to_string(), "Literal[10]"); + assert_eq!(x_ty.display(&db).to_string(), "int"); - db.write_file("/src/foo.py", "x = 10\ndef foo(): pass")?; + db.write_file("/src/foo.py", "x: int = 10\ndef foo(): pass")?; let a = system_path_to_file(&db, "/src/a.py").unwrap(); @@ -6299,7 +6289,7 @@ mod tests { let x_ty_2 = global_symbol(&db, a, "x").expect_type(); - assert_eq!(x_ty_2.display(&db).to_string(), "Literal[10]"); + assert_eq!(x_ty_2.display(&db).to_string(), "int"); let events = db.take_salsa_events(); @@ -6319,15 +6309,15 @@ mod tests { db.write_files([ ("/src/a.py", "from foo import x"), - ("/src/foo.py", "x = 10\ny = 20"), + ("/src/foo.py", "x: int = 10\ny: bool = True"), ])?; let a = system_path_to_file(&db, "/src/a.py").unwrap(); let x_ty = global_symbol(&db, a, "x").expect_type(); - assert_eq!(x_ty.display(&db).to_string(), "Literal[10]"); + assert_eq!(x_ty.display(&db).to_string(), "int"); - db.write_file("/src/foo.py", "x = 10\ny = 30")?; + db.write_file("/src/foo.py", "x: int = 10\ny: bool = False")?; let a = system_path_to_file(&db, "/src/a.py").unwrap(); @@ -6335,7 +6325,7 @@ mod tests { let x_ty_2 = global_symbol(&db, a, "x").expect_type(); - assert_eq!(x_ty_2.display(&db).to_string(), "Literal[10]"); + assert_eq!(x_ty_2.display(&db).to_string(), "int"); let events = db.take_salsa_events();