diff --git a/Cargo.lock b/Cargo.lock index d463cd49..fcdf85c0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -464,6 +464,15 @@ dependencies = [ "syn 2.0.87", ] +[[package]] +name = "atoi" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f28d99ec8bfea296261ca1af174f24225171fea9664ba9003cbebee704810528" +dependencies = [ + "num-traits", +] + [[package]] name = "atty" version = "0.2.14" @@ -808,6 +817,12 @@ dependencies = [ "vsimd", ] +[[package]] +name = "base64ct" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b" + [[package]] name = "bincode" version = "1.3.3" @@ -828,6 +843,9 @@ name = "bitflags" version = "2.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de" +dependencies = [ + "serde", +] [[package]] name = "blake2" @@ -1108,6 +1126,15 @@ dependencies = [ "unicode-width", ] +[[package]] +name = "concurrent-queue" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ca0197aee26d1ae37445ee532fefce43251d24cc7c166799f4d46817f1d3973" +dependencies = [ + "crossbeam-utils", +] + [[package]] name = "console" version = "0.15.8" @@ -1121,6 +1148,12 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "const-oid" +version = "0.9.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c2459377285ad874054d797f3ccebf984978aa39129f6eafde5cdc8315b612f8" + [[package]] name = "const-random" version = "0.1.18" @@ -1181,6 +1214,21 @@ dependencies = [ "libc", ] +[[package]] +name = "crc" +version = "3.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69e6e4d7b33a94f0991c26729976b10ebde1d34c3ee82408fb536164fa10d636" +dependencies = [ + "crc-catalog", 
+] + +[[package]] +name = "crc-catalog" +version = "2.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "19d374276b40fb8bbdee95aef7c7fa6b5316ec764510eb64b8dd0e2ed0d7e7f5" + [[package]] name = "crc32fast" version = "1.4.2" @@ -1594,6 +1642,17 @@ dependencies = [ "sqlparser", ] +[[package]] +name = "der" +version = "0.7.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f55bf8e7b65898637379c1b74eb1551107c8294ed26d855ceb9fd1a09cfc9bc0" +dependencies = [ + "const-oid", + "pem-rfc7468", + "zeroize", +] + [[package]] name = "deranged" version = "0.3.11" @@ -1623,6 +1682,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" dependencies = [ "block-buffer", + "const-oid", "crypto-common", "subtle", ] @@ -1685,6 +1745,12 @@ version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fea41bba32d969b513997752735605054bc0dfa92b4c56bf1189f2e174be7a10" +[[package]] +name = "dotenvy" +version = "0.15.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1aaf95b3e5c8f23aa320147307562d361db0ae0d51242340f558153b4eb2439b" + [[package]] name = "educe" version = "0.6.0" @@ -1702,6 +1768,9 @@ name = "either" version = "1.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0" +dependencies = [ + "serde", +] [[package]] name = "encode_unicode" @@ -1828,6 +1897,28 @@ version = "0.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5692dd7b5a1978a5aeb0ce83b7655c58ca8efdcb79d21036ea249da95afec2c6" +[[package]] +name = "etcetera" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"136d1b5283a1ab77bd9257427ffd09d8667ced0570b6f938942bc7568ed5b943" +dependencies = [ + "cfg-if", + "home", + "windows-sys 0.48.0", +] + +[[package]] +name = "event-listener" +version = "5.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6032be9bd27023a771701cc49f9f053c751055f71efb2e0ae5c15809093675ba" +dependencies = [ + "concurrent-queue", + "parking", + "pin-project-lite", +] + [[package]] name = "fallible-iterator" version = "0.2.0" @@ -1895,6 +1986,17 @@ dependencies = [ "num-traits", ] +[[package]] +name = "flume" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da0e4dd2a88388a1f4ccc7c9ce104604dab68d9f408dc34cd45823d5a9069095" +dependencies = [ + "futures-core", + "futures-sink", + "spin 0.9.8", +] + [[package]] name = "fnv" version = "1.0.7" @@ -1961,6 +2063,17 @@ dependencies = [ "futures-util", ] +[[package]] +name = "futures-intrusive" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d930c203dd0b6ff06e0201a4a2fe9149b43c684fd4420555b26d21b1a02956f" +dependencies = [ + "futures-core", + "lock_api", + "parking_lot", +] + [[package]] name = "futures-io" version = "0.3.31" @@ -2113,6 +2226,15 @@ version = "0.15.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3a9bfc1af68b1726ea47d3d5109de126281def866b33970e10fbab11b5dafab3" +[[package]] +name = "hashlink" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ba4ff7128dee98c7dc9794b6a411377e1404dba1c97deb8d1a55297bd25d8af" +dependencies = [ + "hashbrown 0.14.5", +] + [[package]] name = "heck" version = "0.4.1" @@ -2152,6 +2274,15 @@ version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" +[[package]] +name = "hkdf" +version = "0.12.4" 
+source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b5f8eb2ad728638ea2c7d47a21db23b7b58a72ed6a38256b8a1849f15fbbdf7" +dependencies = [ + "hmac", +] + [[package]] name = "hmac" version = "0.12.1" @@ -2161,6 +2292,15 @@ dependencies = [ "digest", ] +[[package]] +name = "home" +version = "0.5.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3d1354bf6b7235cb4a0576c2619fd4ed18183f689b12b006a0ee7329eeff9a5" +dependencies = [ + "windows-sys 0.52.0", +] + [[package]] name = "http" version = "0.2.12" @@ -2529,6 +2669,9 @@ name = "lazy_static" version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" +dependencies = [ + "spin 0.9.8", +] [[package]] name = "lexical-core" @@ -2650,6 +2793,17 @@ dependencies = [ "libc", ] +[[package]] +name = "libsqlite3-sys" +version = "0.30.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2e99fb7a497b1e3339bc746195567ed8d3e24945ecd636e3619d20b9de9e9149" +dependencies = [ + "cc", + "pkg-config", + "vcpkg", +] + [[package]] name = "libtest-mimic" version = "0.7.3" @@ -2751,6 +2905,12 @@ version = "0.3.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" +[[package]] +name = "minimal-lexical" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" + [[package]] name = "miniz_oxide" version = "0.8.0" @@ -2798,6 +2958,16 @@ dependencies = [ "libc", ] +[[package]] +name = "nom" +version = "7.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d273983c5a657a70a3e8f2a01329822f3b8c8172b73826411a55751e404a0a4a" +dependencies = [ + "memchr", + 
"minimal-lexical", +] + [[package]] name = "normalize-line-endings" version = "0.3.0" @@ -2838,6 +3008,23 @@ dependencies = [ "num-traits", ] +[[package]] +name = "num-bigint-dig" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc84195820f291c7697304f3cbdadd1cb7199c0efc917ff5eafd71225c136151" +dependencies = [ + "byteorder", + "lazy_static", + "libm", + "num-integer", + "num-iter", + "num-traits", + "rand", + "smallvec", + "zeroize", +] + [[package]] name = "num-complex" version = "0.4.6" @@ -3055,6 +3242,24 @@ dependencies = [ "serde_with", ] +[[package]] +name = "optd-ng-kernel" +version = "0.1.1" +dependencies = [ + "anyhow", + "arrow-schema", + "async-recursion", + "async-trait", + "chrono", + "itertools 0.13.0", + "ordered-float 4.5.0", + "serde", + "serde_json", + "sqlx", + "tokio", + "tracing", +] + [[package]] name = "optd-perfbench" version = "0.1.0" @@ -3248,6 +3453,15 @@ version = "1.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" +[[package]] +name = "pem-rfc7468" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "88b39c9bfcfc231068454382784bb460aae594343fb030d46e9f50a645418412" +dependencies = [ + "base64ct", +] + [[package]] name = "percent-encoding" version = "2.3.1" @@ -3334,6 +3548,27 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" +[[package]] +name = "pkcs1" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8ffb9f10fa047879315e6625af03c164b16962a5368d724ed16323b68ace47f" +dependencies = [ + "der", + "pkcs8", + "spki", +] + +[[package]] +name = "pkcs8" +version = "0.10.2" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "f950b2377845cebe5cf8b5165cb3cc1a5e0fa5cfa3e1f7f55707d8fd82e0a7b7" +dependencies = [ + "der", + "spki", +] + [[package]] name = "pkg-config" version = "0.3.31" @@ -3702,6 +3937,26 @@ version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3582f63211428f83597b51b2ddb88e2a91a9d52d12831f9d08f5e624e8977422" +[[package]] +name = "rsa" +version = "0.9.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47c75d7c5c6b673e58bf54d8544a9f432e3a925b0e80f7cd3602ab5c50c55519" +dependencies = [ + "const-oid", + "digest", + "num-bigint-dig", + "num-integer", + "num-traits", + "pkcs1", + "pkcs8", + "rand_core", + "signature", + "spki", + "subtle", + "zeroize", +] + [[package]] name = "rstest" version = "0.17.0" @@ -4002,6 +4257,17 @@ dependencies = [ "unsafe-libyaml", ] +[[package]] +name = "sha1" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + [[package]] name = "sha2" version = "0.10.8" @@ -4028,6 +4294,25 @@ version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" +[[package]] +name = "signal-hook-registry" +version = "1.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a9e9e0b4211b72e7b8b6e85c807d36c212bdb33ea8587f7569562a84df5465b1" +dependencies = [ + "libc", +] + +[[package]] +name = "signature" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77549399552de45a898a580c1b41d445bf730df867cc44e6c0233bbc4b8329de" +dependencies = [ + "digest", + "rand_core", +] + [[package]] name = "similar" version = "2.6.0" 
@@ -4054,6 +4339,9 @@ name = "smallvec" version = "1.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" +dependencies = [ + "serde", +] [[package]] name = "snafu" @@ -4104,6 +4392,29 @@ name = "spin" version = "0.9.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" +dependencies = [ + "lock_api", +] + +[[package]] +name = "spki" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d91ed6c858b01f942cd56b37a94b3e0a1798290327d1236e4d9cf4eaca44d29d" +dependencies = [ + "base64ct", + "der", +] + +[[package]] +name = "sqlformat" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7bba3a93db0cc4f7bdece8bb09e77e2e785c20bfebf79eb8340ed80708048790" +dependencies = [ + "nom", + "unicode_categories", +] [[package]] name = "sqllogictest" @@ -4168,6 +4479,200 @@ dependencies = [ "tokio", ] +[[package]] +name = "sqlx" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93334716a037193fac19df402f8571269c84a00852f6a7066b5d2616dcd64d3e" +dependencies = [ + "sqlx-core", + "sqlx-macros", + "sqlx-mysql", + "sqlx-postgres", + "sqlx-sqlite", +] + +[[package]] +name = "sqlx-core" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d4d8060b456358185f7d50c55d9b5066ad956956fddec42ee2e8567134a8936e" +dependencies = [ + "atoi", + "byteorder", + "bytes", + "crc", + "crossbeam-queue", + "either", + "event-listener", + "futures-channel", + "futures-core", + "futures-intrusive", + "futures-io", + "futures-util", + "hashbrown 0.14.5", + "hashlink", + "hex", + "indexmap 2.6.0", + "log", + "memchr", + "once_cell", + "paste", + "percent-encoding", + "serde", + 
"serde_json", + "sha2", + "smallvec", + "sqlformat", + "thiserror 1.0.68", + "tokio", + "tokio-stream", + "tracing", + "url", +] + +[[package]] +name = "sqlx-macros" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cac0692bcc9de3b073e8d747391827297e075c7710ff6276d9f7a1f3d58c6657" +dependencies = [ + "proc-macro2", + "quote", + "sqlx-core", + "sqlx-macros-core", + "syn 2.0.87", +] + +[[package]] +name = "sqlx-macros-core" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1804e8a7c7865599c9c79be146dc8a9fd8cc86935fa641d3ea58e5f0688abaa5" +dependencies = [ + "dotenvy", + "either", + "heck 0.5.0", + "hex", + "once_cell", + "proc-macro2", + "quote", + "serde", + "serde_json", + "sha2", + "sqlx-core", + "sqlx-mysql", + "sqlx-postgres", + "sqlx-sqlite", + "syn 2.0.87", + "tempfile", + "tokio", + "url", +] + +[[package]] +name = "sqlx-mysql" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "64bb4714269afa44aef2755150a0fc19d756fb580a67db8885608cf02f47d06a" +dependencies = [ + "atoi", + "base64 0.22.1", + "bitflags 2.6.0", + "byteorder", + "bytes", + "crc", + "digest", + "dotenvy", + "either", + "futures-channel", + "futures-core", + "futures-io", + "futures-util", + "generic-array", + "hex", + "hkdf", + "hmac", + "itoa", + "log", + "md-5", + "memchr", + "once_cell", + "percent-encoding", + "rand", + "rsa", + "serde", + "sha1", + "sha2", + "smallvec", + "sqlx-core", + "stringprep", + "thiserror 1.0.68", + "tracing", + "whoami", +] + +[[package]] +name = "sqlx-postgres" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6fa91a732d854c5d7726349bb4bb879bb9478993ceb764247660aee25f67c2f8" +dependencies = [ + "atoi", + "base64 0.22.1", + "bitflags 2.6.0", + "byteorder", + "crc", + "dotenvy", + "etcetera", + "futures-channel", + "futures-core", + 
"futures-io", + "futures-util", + "hex", + "hkdf", + "hmac", + "home", + "itoa", + "log", + "md-5", + "memchr", + "once_cell", + "rand", + "serde", + "serde_json", + "sha2", + "smallvec", + "sqlx-core", + "stringprep", + "thiserror 1.0.68", + "tracing", + "whoami", +] + +[[package]] +name = "sqlx-sqlite" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d5b2cf34a45953bfd3daaf3db0f7a7878ab9b7a6b91b422d24a7a9e4c857b680" +dependencies = [ + "atoi", + "flume", + "futures-channel", + "futures-core", + "futures-executor", + "futures-intrusive", + "futures-util", + "libsqlite3-sys", + "log", + "percent-encoding", + "serde", + "serde_urlencoded", + "sqlx-core", + "tracing", + "url", +] + [[package]] name = "stable_deref_trait" version = "1.2.0" @@ -4551,6 +5056,7 @@ dependencies = [ "mio", "parking_lot", "pin-project-lite", + "signal-hook-registry", "socket2", "tokio-macros", "windows-sys 0.52.0", @@ -4811,6 +5317,12 @@ version = "0.1.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7dd6e30e90baa6f72411720665d41d89b9a3d039dc45b8faea1ddd07f617f6af" +[[package]] +name = "unicode_categories" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e" + [[package]] name = "unsafe-libyaml" version = "0.2.11" @@ -4880,6 +5392,12 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d" +[[package]] +name = "vcpkg" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" + [[package]] name = "version_check" version = "0.9.5" diff --git a/Cargo.toml b/Cargo.toml index 44a4b5be..061acd19 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,6 
+10,7 @@ members = [ "optd-perfbench", "optd-datafusion-repr-adv-cost", "optd-sqllogictest", + "optd-ng-kernel", ] resolver = "2" diff --git a/optd-ng-kernel/Cargo.toml b/optd-ng-kernel/Cargo.toml new file mode 100644 index 00000000..6e2f77bc --- /dev/null +++ b/optd-ng-kernel/Cargo.toml @@ -0,0 +1,27 @@ +[package] +name = "optd-ng-kernel" +version.workspace = true +edition.workspace = true +homepage.workspace = true +keywords.workspace = true +license.workspace = true +repository.workspace = true + +[dependencies] +anyhow = "1" +async-recursion = "1" +async-trait = "0.1" +arrow-schema = "47.0.0" +tracing = "0.1" +ordered-float = "4" +itertools = "0.13" +serde = { version = "1.0", features = ["derive", "rc"] } +chrono = "0.4" +sqlx = { version = "0.8", features = [ + "runtime-tokio", + "sqlite", +] } # TODO: strip the features, move to another crate +serde_json = { version = "1" } # TODO: move to another crate + +[dev-dependencies] +tokio = { version = "1", features = ["full"] } diff --git a/optd-ng-kernel/src/cascades.rs b/optd-ng-kernel/src/cascades.rs new file mode 100644 index 00000000..e9de8ce6 --- /dev/null +++ b/optd-ng-kernel/src/cascades.rs @@ -0,0 +1,9 @@ +// Copyright (c) 2023-2024 CMU Database Group +// +// Use of this source code is governed by an MIT-style license that can be found in the LICENSE file or at +// https://opensource.org/licenses/MIT. + +pub mod memo; +pub mod naive_memo; +pub mod optimizer; +pub mod persistent_memo; diff --git a/optd-ng-kernel/src/cascades/memo.rs b/optd-ng-kernel/src/cascades/memo.rs new file mode 100644 index 00000000..023e5cdb --- /dev/null +++ b/optd-ng-kernel/src/cascades/memo.rs @@ -0,0 +1,121 @@ +// Copyright (c) 2023-2024 CMU Database Group +// +// Use of this source code is governed by an MIT-style license that can be found in the LICENSE file or at +// https://opensource.org/licenses/MIT. 
+ +use std::collections::HashSet; +use std::sync::Arc; + +use super::optimizer::{ExprId, GroupId, PredId}; +use crate::nodes::{ArcPlanNode, ArcPredNode, NodeType, PlanNodeOrGroup}; +use async_trait::async_trait; + +pub type ArcMemoPlanNode = Arc>; + +/// The RelNode representation in the memo table. Store children as group IDs. Equivalent to MExpr +/// in Columbia/Cascades. +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub struct MemoPlanNode { + pub typ: T, + pub children: Vec, + pub predicates: Vec, +} + +impl std::fmt::Display for MemoPlanNode { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "({}", self.typ)?; + for child in &self.children { + write!(f, " {}", child)?; + } + for pred in &self.predicates { + write!(f, " {}", pred)?; + } + write!(f, ")") + } +} + +#[derive(Clone)] +pub struct WinnerInfo {/* unimplemented */} + +#[derive(Clone)] +pub enum Winner { + Unknown, + Impossible, + Full(WinnerInfo), +} + +impl Winner { + pub fn has_full_winner(&self) -> bool { + matches!(self, Self::Full { .. }) + } + + pub fn has_decided(&self) -> bool { + matches!(self, Self::Full { .. } | Self::Impossible) + } + + pub fn as_full_winner(&self) -> Option<&WinnerInfo> { + match self { + Self::Full(info) => Some(info), + _ => None, + } + } +} + +impl Default for Winner { + fn default() -> Self { + Self::Unknown + } +} + +#[derive(Default, Clone)] +pub struct GroupInfo { + pub winner: Winner, +} + +pub struct Group { + pub(crate) group_exprs: HashSet, + pub(crate) info: GroupInfo, +} + +/// Trait for memo table implementations. TODO: use GAT in the future. +#[async_trait] +pub trait Memo: 'static + Send + Sync { + /// Add an expression to the memo table. If the expression already exists, it will return the + /// existing group id and expr id. Otherwise, a new group and expr will be created. + async fn add_new_expr(&mut self, rel_node: ArcPlanNode) -> (GroupId, ExprId); + + /// Add a new expression to an existing gruop. 
If the expression is a group, it will merge the + /// two groups. Otherwise, it will add the expression to the group. Returns the expr id if + /// the expression is not a group. + async fn add_expr_to_group( + &mut self, + rel_node: PlanNodeOrGroup, + group_id: GroupId, + ) -> Option; + + /// Add a new predicate into the memo table. + async fn add_new_pred(&mut self, pred_node: ArcPredNode) -> PredId; + + /// Get the group id of an expression. + /// The group id is volatile, depending on whether the groups are merged. + async fn get_group_id(&self, expr_id: ExprId) -> GroupId; + + /// Get the memoized representation of a node. + async fn get_expr_memoed(&self, expr_id: ExprId) -> ArcMemoPlanNode; + + /// Get all groups IDs in the memo table. + async fn get_all_group_ids(&self) -> Vec; + + /// Get a predicate by ID + async fn get_pred(&self, pred_id: PredId) -> ArcPredNode; + + /// Estimated plan space for the memo table, only useful when plan exploration budget is + /// enabled. Returns number of expressions in the memo table. + async fn estimated_plan_space(&self) -> usize; + + // The below functions can be overwritten by the memo table implementation if there + // are more efficient way to retrieve the information. + + /// Get all expressions in the group. + async fn get_all_exprs_in_group(&self, group_id: GroupId) -> Vec; +} diff --git a/optd-ng-kernel/src/cascades/naive_memo.rs b/optd-ng-kernel/src/cascades/naive_memo.rs new file mode 100644 index 00000000..a8271267 --- /dev/null +++ b/optd-ng-kernel/src/cascades/naive_memo.rs @@ -0,0 +1,559 @@ +// Copyright (c) 2023-2024 CMU Database Group +// +// Use of this source code is governed by an MIT-style license that can be found in the LICENSE file or at +// https://opensource.org/licenses/MIT. 
+ +use std::collections::hash_map::Entry; +use std::collections::{HashMap, HashSet}; +use std::sync::Arc; + +use async_trait::async_trait; +use itertools::Itertools; +use tracing::trace; + +use super::memo::{ArcMemoPlanNode, Group, Memo, MemoPlanNode, Winner}; +use super::optimizer::{ExprId, GroupId, PredId}; +use crate::cascades::memo::GroupInfo; +use crate::nodes::{ArcPlanNode, ArcPredNode, NodeType, PlanNodeOrGroup}; + +/// A naive, simple, and unoptimized memo table implementation. +pub struct NaiveMemo { + // Source of truth. + groups: HashMap, + expr_id_to_expr_node: HashMap>, + + // Predicate stuff. + pred_id_to_pred_node: HashMap>, + pred_node_to_pred_id: HashMap, PredId>, + + // Internal states. + group_expr_counter: usize, + + // Indexes. + expr_node_to_expr_id: HashMap, ExprId>, + expr_id_to_group_id: HashMap, + + // We update all group IDs in the memo table upon group merging, but + // there might be edge cases that some tasks still hold the old group ID. + // In this case, we need this mapping to redirect to the merged group ID. 
+ merged_group_mapping: HashMap, + dup_expr_mapping: HashMap, +} + +#[async_trait] +impl Memo for NaiveMemo { + async fn add_new_expr(&mut self, rel_node: ArcPlanNode) -> (GroupId, ExprId) { + self.add_new_expr_inner(rel_node) + } + + async fn add_expr_to_group( + &mut self, + rel_node: PlanNodeOrGroup, + group_id: GroupId, + ) -> Option { + self.add_expr_to_group_inner(rel_node, group_id) + } + + async fn add_new_pred(&mut self, pred_node: ArcPredNode) -> PredId { + self.add_new_pred_inner(pred_node) + } + + async fn get_pred(&self, pred_id: PredId) -> ArcPredNode { + self.get_pred_inner(pred_id) + } + + async fn get_group_id(&self, expr_id: ExprId) -> GroupId { + self.get_group_id_inner(expr_id) + } + + async fn get_expr_memoed(&self, expr_id: ExprId) -> ArcMemoPlanNode { + self.get_expr_memoed_inner(expr_id) + } + + async fn get_all_group_ids(&self) -> Vec { + self.get_all_group_ids_inner() + } + + async fn get_all_exprs_in_group(&self, group_id: GroupId) -> Vec { + let mut expr_ids: Vec = self + .get_group_inner(group_id) + .group_exprs + .iter() + .copied() + .collect(); + expr_ids.sort(); + expr_ids + } + + async fn estimated_plan_space(&self) -> usize { + self.expr_id_to_expr_node.len() + } +} + +impl NaiveMemo { + pub fn new() -> Self { + Self { + expr_id_to_group_id: HashMap::new(), + expr_id_to_expr_node: HashMap::new(), + expr_node_to_expr_id: HashMap::new(), + pred_id_to_pred_node: HashMap::new(), + pred_node_to_pred_id: HashMap::new(), + groups: HashMap::new(), + group_expr_counter: 0, + merged_group_mapping: HashMap::new(), + dup_expr_mapping: HashMap::new(), + } + } + + fn add_new_expr_inner(&mut self, rel_node: ArcPlanNode) -> (GroupId, ExprId) { + let (group_id, expr_id) = self + .add_new_group_expr_inner(rel_node, None) + .expect("should not trigger merge group"); + self.verify_integrity(); + (group_id, expr_id) + } + + fn add_expr_to_group_inner( + &mut self, + rel_node: PlanNodeOrGroup, + group_id: GroupId, + ) -> Option { + match rel_node { + 
PlanNodeOrGroup::Group(input_group) => { + let input_group = self.reduce_group(input_group); + let group_id = self.reduce_group(group_id); + self.merge_group_inner(input_group, group_id); + None + } + PlanNodeOrGroup::PlanNode(rel_node) => { + let reduced_group_id = self.reduce_group(group_id); + let (returned_group_id, expr_id) = self + .add_new_group_expr_inner(rel_node, Some(reduced_group_id)) + .unwrap(); + assert_eq!(returned_group_id, reduced_group_id); + self.verify_integrity(); + Some(expr_id) + } + } + } + + fn add_new_pred_inner(&mut self, pred_node: ArcPredNode) -> PredId { + let pred_id = self.next_pred_id(); + if let Some(id) = self.pred_node_to_pred_id.get(&pred_node) { + return *id; + } + self.pred_node_to_pred_id.insert(pred_node.clone(), pred_id); + self.pred_id_to_pred_node.insert(pred_id, pred_node); + pred_id + } + + fn get_pred_inner(&self, pred_id: PredId) -> ArcPredNode { + self.pred_id_to_pred_node[&pred_id].clone() + } + + fn get_group_id_inner(&self, mut expr_id: ExprId) -> GroupId { + while let Some(new_expr_id) = self.dup_expr_mapping.get(&expr_id) { + expr_id = *new_expr_id; + } + *self + .expr_id_to_group_id + .get(&expr_id) + .expect("expr not found in group mapping") + } + + fn get_expr_memoed_inner(&self, mut expr_id: ExprId) -> ArcMemoPlanNode { + while let Some(new_expr_id) = self.dup_expr_mapping.get(&expr_id) { + expr_id = *new_expr_id; + } + self.expr_id_to_expr_node + .get(&expr_id) + .expect("expr not found in expr mapping") + .clone() + } + + fn get_all_group_ids_inner(&self) -> Vec { + let mut ids = self.groups.keys().copied().collect_vec(); + ids.sort(); + ids + } + + fn get_group_inner(&self, group_id: GroupId) -> &Group { + let group_id = self.reduce_group(group_id); + self.groups.get(&group_id).as_ref().unwrap() + } + + /// Get the next group id. Group id and expr id shares the same counter, so as to make it easier + /// to debug... 
+ fn next_group_id(&mut self) -> GroupId { + let id = self.group_expr_counter; + self.group_expr_counter += 1; + GroupId(id) + } + + /// Get the next expr id. Group id and expr id shares the same counter, so as to make it easier + /// to debug... + fn next_expr_id(&mut self) -> ExprId { + let id = self.group_expr_counter; + self.group_expr_counter += 1; + ExprId(id) + } + + /// Get the next pred id. Group id and expr id shares the same counter, so as to make it easier + /// to debug... + fn next_pred_id(&mut self) -> PredId { + let id = self.group_expr_counter; + self.group_expr_counter += 1; + PredId(id) + } + + fn verify_integrity(&self) { + if cfg!(debug_assertions) { + let num_of_exprs = self.expr_id_to_expr_node.len(); + assert_eq!(num_of_exprs, self.expr_node_to_expr_id.len()); + assert_eq!(num_of_exprs, self.expr_id_to_group_id.len()); + + let mut valid_groups = HashSet::new(); + for to in self.merged_group_mapping.values() { + assert_eq!(self.merged_group_mapping[to], *to); + valid_groups.insert(*to); + } + assert_eq!(valid_groups.len(), self.groups.len()); + + for (id, node) in self.expr_id_to_expr_node.iter() { + assert_eq!(self.expr_node_to_expr_id[node], *id); + for child in &node.children { + assert!( + valid_groups.contains(child), + "invalid group used in expression {}, where {} does not exist any more", + node, + child + ); + } + } + + let mut cnt = 0; + for (group_id, group) in &self.groups { + assert!(valid_groups.contains(group_id)); + cnt += group.group_exprs.len(); + assert!(!group.group_exprs.is_empty()); + for expr in &group.group_exprs { + assert_eq!(self.expr_id_to_group_id[expr], *group_id); + } + } + assert_eq!(cnt, num_of_exprs); + } + } + + fn reduce_group(&self, group_id: GroupId) -> GroupId { + self.merged_group_mapping[&group_id] + } + + fn merge_group_inner(&mut self, merge_into: GroupId, merge_from: GroupId) { + if merge_into == merge_from { + return; + } + trace!(event = "merge_group", merge_into = %merge_into, merge_from = 
%merge_from); + let group_merge_from = self.groups.remove(&merge_from).unwrap(); + let group_merge_into = self.groups.get_mut(&merge_into).unwrap(); + // TODO: update winner, cost and properties + for from_expr in group_merge_from.group_exprs { + let ret = self.expr_id_to_group_id.insert(from_expr, merge_into); + assert!(ret.is_some()); + group_merge_into.group_exprs.insert(from_expr); + } + self.merged_group_mapping.insert(merge_from, merge_into); + + // Update all indexes and other data structures + // 1. update merged group mapping -- could be optimized with union find + for (_, mapped_to) in self.merged_group_mapping.iter_mut() { + if *mapped_to == merge_from { + *mapped_to = merge_into; + } + } + + let mut pending_recursive_merge = Vec::new(); + // 2. update all group expressions and indexes + for (group_id, group) in self.groups.iter_mut() { + let mut new_expr_list = HashSet::new(); + for expr_id in group.group_exprs.iter() { + let expr = self.expr_id_to_expr_node[expr_id].clone(); + if expr.children.contains(&merge_from) { + // Create the new expr node + let old_expr = expr.as_ref().clone(); + let mut new_expr = expr.as_ref().clone(); + new_expr.children.iter_mut().for_each(|x| { + if *x == merge_from { + *x = merge_into; + } + }); + // Update all existing entries and indexes + self.expr_id_to_expr_node + .insert(*expr_id, Arc::new(new_expr.clone())); + self.expr_node_to_expr_id.remove(&old_expr); + if let Some(dup_expr) = self.expr_node_to_expr_id.get(&new_expr) { + // If new_expr == some_other_old_expr in the memo table, unless they belong + // to the same group, we should merge the two + // groups. This should not happen. We should simply drop this expression. 
+ let dup_group_id = self.expr_id_to_group_id[dup_expr]; + if dup_group_id != *group_id { + pending_recursive_merge.push((dup_group_id, *group_id)); + } + self.expr_id_to_expr_node.remove(expr_id); + self.expr_id_to_group_id.remove(expr_id); + self.dup_expr_mapping.insert(*expr_id, *dup_expr); + new_expr_list.insert(*dup_expr); // adding this temporarily -- should be + // removed once recursive merge finishes + } else { + self.expr_node_to_expr_id.insert(new_expr, *expr_id); + new_expr_list.insert(*expr_id); + } + } else { + new_expr_list.insert(*expr_id); + } + } + assert!(!new_expr_list.is_empty()); + group.group_exprs = new_expr_list; + } + for (merge_from, merge_into) in pending_recursive_merge { + // We need to reduce because each merge would probably invalidate some groups in the + // last loop iteration. + let merge_from = self.reduce_group(merge_from); + let merge_into = self.reduce_group(merge_into); + self.merge_group_inner(merge_into, merge_from); + } + } + + fn add_new_group_expr_inner( + &mut self, + rel_node: ArcPlanNode, + add_to_group_id: Option, + ) -> anyhow::Result<(GroupId, ExprId)> { + let children_group_ids = rel_node + .children + .iter() + .map(|child| { + match child { + // TODO: can I remove reduce? + PlanNodeOrGroup::Group(group) => self.reduce_group(*group), + PlanNodeOrGroup::PlanNode(child) => { + // No merge / modification to the memo should occur for the following + // operation + let (group, _) = self + .add_new_group_expr_inner(child.clone(), None) + .expect("should not trigger merge group"); + self.reduce_group(group) // TODO: can I remove? 
+ } + } + }) + .collect::>(); + let memo_node = MemoPlanNode { + typ: rel_node.typ.clone(), + children: children_group_ids, + predicates: rel_node + .predicates + .iter() + .map(|x| self.add_new_pred_inner(x.clone())) + .collect(), + }; + if let Some(&expr_id) = self.expr_node_to_expr_id.get(&memo_node) { + let group_id = self.expr_id_to_group_id[&expr_id]; + if let Some(add_to_group_id) = add_to_group_id { + let add_to_group_id = self.reduce_group(add_to_group_id); + self.merge_group_inner(add_to_group_id, group_id); + return Ok((add_to_group_id, expr_id)); + } + return Ok((group_id, expr_id)); + } + let expr_id = self.next_expr_id(); + let group_id = if let Some(group_id) = add_to_group_id { + group_id + } else { + self.next_group_id() + }; + self.expr_id_to_expr_node + .insert(expr_id, memo_node.clone().into()); + self.expr_id_to_group_id.insert(expr_id, group_id); + self.expr_node_to_expr_id.insert(memo_node.clone(), expr_id); + self.append_expr_to_group(expr_id, group_id, memo_node); + Ok((group_id, expr_id)) + } + + /// This is inefficient: usually the optimizer should have a MemoRef instead of passing the full + /// rel node. Should be only used for debugging purpose. 
+ #[cfg(test)] + pub(crate) fn get_expr_info(&self, rel_node: ArcPlanNode) -> (GroupId, ExprId) { + let children_group_ids = rel_node + .children + .iter() + .map(|child| match child { + PlanNodeOrGroup::Group(group) => *group, + PlanNodeOrGroup::PlanNode(child) => self.get_expr_info(child.clone()).0, + }) + .collect::>(); + let memo_node = MemoPlanNode { + typ: rel_node.typ.clone(), + children: children_group_ids, + predicates: rel_node + .predicates + .iter() + .map(|x| self.pred_node_to_pred_id[x]) + .collect(), + }; + let Some(&expr_id) = self.expr_node_to_expr_id.get(&memo_node) else { + unreachable!("not found {}", memo_node) + }; + let group_id = self.expr_id_to_group_id[&expr_id]; + (group_id, expr_id) + } + + /// If group_id exists, it adds expr_id to the existing group + /// Otherwise, it creates a new group of that group_id and insert expr_id into the new group + fn append_expr_to_group( + &mut self, + expr_id: ExprId, + group_id: GroupId, + memo_node: MemoPlanNode, + ) { + trace!(event = "add_expr_to_group", group_id = %group_id, expr_id = %expr_id, memo_node = %memo_node); + if let Entry::Occupied(mut entry) = self.groups.entry(group_id) { + let group = entry.get_mut(); + group.group_exprs.insert(expr_id); + return; + } + // Create group and infer properties (only upon initializing a group). 
+ let mut group = Group { + group_exprs: HashSet::new(), + info: GroupInfo::default(), + }; + group.group_exprs.insert(expr_id); + self.groups.insert(group_id, group); + self.merged_group_mapping.insert(group_id, group_id); + } + + pub fn clear_winner(&mut self) { + for group in self.groups.values_mut() { + group.info.winner = Winner::Unknown; + } + } +} + +#[cfg(test)] +pub(crate) mod tests { + use super::*; + use crate::{ + nodes::Value, + tests::common::{expr, group, join, list, project, scan, MemoTestRelTyp}, + }; + + #[tokio::test] + async fn add_predicate() { + let mut memo = NaiveMemo::::new(); + let pred_node = list(vec![expr(Value::Int32(233))]); + let p1 = memo.add_new_pred(pred_node.clone()).await; + let p2 = memo.add_new_pred(pred_node.clone()).await; + assert_eq!(p1, p2); + } + + #[tokio::test] + async fn group_merge_1() { + let mut memo = NaiveMemo::new(); + let (group_id, _) = memo + .add_new_expr(join(scan("t1"), scan("t2"), expr(Value::Bool(true)))) + .await; + memo.add_expr_to_group( + join(scan("t2"), scan("t1"), expr(Value::Bool(true))).into(), + group_id, + ) + .await; + assert_eq!(memo.get_all_exprs_in_group(group_id).await.len(), 2); + } + + #[tokio::test] + async fn group_merge_2() { + let mut memo = NaiveMemo::new(); + let (group_id_1, _) = memo + .add_new_expr(project( + join(scan("t1"), scan("t2"), expr(Value::Bool(true))), + list(vec![expr(Value::Int64(1))]), + )) + .await; + let (group_id_2, _) = memo + .add_new_expr(project( + join(scan("t1"), scan("t2"), expr(Value::Bool(true))), + list(vec![expr(Value::Int64(1))]), + )) + .await; + assert_eq!(group_id_1, group_id_2); + } + + #[tokio::test] + async fn group_merge_3() { + let mut memo = NaiveMemo::new(); + let expr1 = project(scan("t1"), list(vec![expr(Value::Int64(1))])); + let expr2 = project(scan("t1-alias"), list(vec![expr(Value::Int64(1))])); + memo.add_new_expr(expr1.clone()).await; + memo.add_new_expr(expr2.clone()).await; + // merging two child groups causes parent to merge + 
let (group_id_expr, _) = memo.get_expr_info(scan("t1")); + memo.add_expr_to_group(scan("t1-alias").into(), group_id_expr) + .await; + let (group_1, _) = memo.get_expr_info(expr1); + let (group_2, _) = memo.get_expr_info(expr2); + assert_eq!(group_1, group_2); + } + + #[tokio::test] + async fn group_merge_4() { + let mut memo = NaiveMemo::new(); + let expr1 = project( + project(scan("t1"), list(vec![expr(Value::Int64(1))])), + list(vec![expr(Value::Int64(2))]), + ); + let expr2 = project( + project(scan("t1-alias"), list(vec![expr(Value::Int64(1))])), + list(vec![expr(Value::Int64(2))]), + ); + memo.add_new_expr(expr1.clone()).await; + memo.add_new_expr(expr2.clone()).await; + // merge two child groups, cascading merge + let (group_id_expr, _) = memo.get_expr_info(scan("t1")); + memo.add_expr_to_group(scan("t1-alias").into(), group_id_expr) + .await; + let (group_1, _) = memo.get_expr_info(expr1.clone()); + let (group_2, _) = memo.get_expr_info(expr2.clone()); + assert_eq!(group_1, group_2); + let (group_1, _) = memo.get_expr_info(expr1.child_rel(0)); + let (group_2, _) = memo.get_expr_info(expr2.child_rel(0)); + assert_eq!(group_1, group_2); + } + + #[tokio::test] + async fn group_merge_5() { + let mut memo = NaiveMemo::new(); + let expr1 = project( + project(scan("t1"), list(vec![expr(Value::Int64(1))])), + list(vec![expr(Value::Int64(2))]), + ); + let expr2 = project( + project(scan("t1-alias"), list(vec![expr(Value::Int64(1))])), + list(vec![expr(Value::Int64(2))]), + ); + let (_, expr1_id) = memo.add_new_expr(expr1.clone()).await; + let (_, expr2_id) = memo.add_new_expr(expr2.clone()).await; + + // experimenting with group id in expr (i.e., when apply rules) + let (scan_t1, _) = memo.get_expr_info(scan("t1")); + let pred = list(vec![expr(Value::Int64(1))]); + let proj_binding = project(group(scan_t1), pred); + let middle_proj_2 = memo.get_expr_memoed(expr2_id).await.children[0]; + + memo.add_expr_to_group(proj_binding.into(), middle_proj_2) + .await; + + 
assert_eq!( + memo.get_expr_memoed(expr1_id).await, + memo.get_expr_memoed(expr2_id).await + ); // these two expressions are merged + assert_eq!(memo.get_expr_info(expr1), memo.get_expr_info(expr2)); + } +} diff --git a/optd-ng-kernel/src/cascades/optimizer.rs b/optd-ng-kernel/src/cascades/optimizer.rs new file mode 100644 index 00000000..8bfcd00e --- /dev/null +++ b/optd-ng-kernel/src/cascades/optimizer.rs @@ -0,0 +1,33 @@ +// Copyright (c) 2023-2024 CMU Database Group +// +// Use of this source code is governed by an MIT-style license that can be found in the LICENSE file or at +// https://opensource.org/licenses/MIT. + +use std::fmt::Display; + +#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Debug, Default, Hash)] +pub struct GroupId(pub(super) usize); + +#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Debug, Default, Hash)] +pub struct ExprId(pub usize); + +#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Debug, Default, Hash)] +pub struct PredId(pub usize); + +impl Display for GroupId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "!{}", self.0) + } +} + +impl Display for ExprId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + +impl Display for PredId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "P{}", self.0) + } +} diff --git a/optd-ng-kernel/src/cascades/persistent_memo.rs b/optd-ng-kernel/src/cascades/persistent_memo.rs new file mode 100644 index 00000000..6c48cc2c --- /dev/null +++ b/optd-ng-kernel/src/cascades/persistent_memo.rs @@ -0,0 +1,555 @@ +use std::marker::PhantomData; + +use async_recursion::async_recursion; +use async_trait::async_trait; +use sqlx::{Row, SqlitePool}; + +use crate::nodes::{ArcPlanNode, ArcPredNode, PersistentNodeType, PlanNodeOrGroup}; + +use super::{ + memo::{ArcMemoPlanNode, Memo, MemoPlanNode}, + optimizer::{ExprId, GroupId, PredId}, +}; + +/// A persistent memo table 
implementation. +pub struct PersistentMemo { + db_conn: SqlitePool, // TODO: make this a generic + _phantom: std::marker::PhantomData, +} + +impl PersistentMemo { + pub async fn new(db_conn: SqlitePool) -> Self { + Self { + db_conn, + _phantom: PhantomData, + } + } + + pub async fn setup(&mut self) -> anyhow::Result<()> { + // TODO: use migration + sqlx::query("CREATE TABLE groups(group_id INTEGER PRIMARY KEY AUTOINCREMENT)") + .execute(&self.db_conn) + .await?; + sqlx::query("CREATE TABLE group_merges(from_group_id INTEGER PRIMARY KEY AUTOINCREMENT, to_group_id INTEGER)") + .execute(&self.db_conn) + .await?; + // Ideally, tag should be an enum, and we should populate that enum column based on the tag. + sqlx::query("CREATE TABLE group_exprs(group_expr_id INTEGER PRIMARY KEY AUTOINCREMENT, group_id INTEGER, tag TEXT, is_logical bool, children JSON DEFAULT('[]'), predicates JSON DEFAULT('[]'))") + .execute(&self.db_conn) + .await?; + sqlx::query( + "CREATE TABLE predicates(predicate_id INTEGER PRIMARY KEY AUTOINCREMENT, data JSON)", + ) + .execute(&self.db_conn) + .await?; + Ok(()) + } +} + +pub async fn new_in_memory() -> anyhow::Result> { + let db_conn = sqlx::SqlitePool::connect("sqlite::memory:").await?; + Ok(PersistentMemo::new(db_conn).await) +} + +impl PersistentMemo { + #[cfg(test)] + async fn lookup_predicate(&self, pred_node: ArcPredNode) -> Option { + let data = T::serialize_pred(&pred_node); + let pred_id = sqlx::query("SELECT predicate_id FROM predicates WHERE data = ?") + .bind(&data) + .fetch_optional(&self.db_conn) + .await + .unwrap(); + pred_id.map(|row| PredId(row.get::(0) as usize)) + } + + /// This is inefficient: usually the optimizer should have a MemoRef instead of passing the full + /// rel node. Should be only used for debugging purpose. 
+ #[cfg(test)] + #[async_recursion] + async fn get_expr_info(&self, rel_node: ArcPlanNode) -> (GroupId, ExprId) { + let mut children_group_ids = Vec::new(); + for child in &rel_node.children { + let group = match child { + PlanNodeOrGroup::Group(group) => *group, + PlanNodeOrGroup::PlanNode(child) => { + let (group_id, _) = self.get_expr_info(child.clone()).await; + group_id + } + }; + children_group_ids.push(group.0); + } + let mut children_predicates = Vec::new(); + for pred in &rel_node.predicates { + let pred_id = self.lookup_predicate(pred.clone()).await.unwrap(); + children_predicates.push(pred_id.0); + } + let tag = T::serialize_plan_tag(rel_node.typ.clone()); + // We keep duplicated expressions in the table, but we retrieve the first expr_id + let row = sqlx::query("SELECT group_expr_id, group_id FROM group_exprs WHERE tag = ? AND children = ? AND predicates = ? ORDER BY group_expr_id") + .bind(&tag) + .bind(serde_json::to_value(&children_group_ids).unwrap()) + .bind(serde_json::to_value(&children_predicates).unwrap()) + .fetch_optional(&self.db_conn) + .await + .unwrap() + .unwrap(); + let expr_id = row.get::("group_expr_id"); + let expr_id = ExprId(expr_id as usize); + let group_id = row.get::("group_id"); + let group_id = GroupId(group_id as usize); + (group_id, expr_id) + } + + #[async_recursion] + async fn add_expr_to_group_inner( + &mut self, + rel_node: ArcPlanNode, + add_to_group: Option, + ) -> (GroupId, ExprId) { + let mut children_groups = Vec::new(); + for child in rel_node.children.iter() { + let group = match child { + PlanNodeOrGroup::Group(group) => { + // The user-provided group could contain a stale ID + self.reduce_group(*group).await + } + PlanNodeOrGroup::PlanNode(child) => { + let (group_id, _) = self.add_expr_to_group_inner(child.clone(), None).await; + group_id + } + }; + children_groups.push(group.0); + } + let mut predicates = Vec::new(); + for pred in rel_node.predicates.iter() { + let pred_id = 
self.add_new_pred(pred.clone()).await; + predicates.push(pred_id.0); + } + let tag = T::serialize_plan_tag(rel_node.typ.clone()); + // check if we already have an expr in the database + let row = + sqlx::query("SELECT group_expr_id, group_id FROM group_exprs WHERE tag = ? AND children = ? AND predicates = ?") + .bind(&tag) + .bind(serde_json::to_value(&children_groups).unwrap()) + .bind(serde_json::to_value(&predicates).unwrap()) + .fetch_optional(&self.db_conn) + .await + .unwrap(); + if let Some(row) = row { + let expr_id = row.get::("group_expr_id"); + let expr_id = ExprId(expr_id as usize); + let group_id = row.get::("group_id"); + let group_id = GroupId(group_id as usize); + if let Some(add_to_group) = add_to_group { + self.merge_group_inner(group_id, add_to_group).await; + (add_to_group, expr_id) + } else { + (group_id, expr_id) + } + } else { + let group_id = if let Some(add_to_group) = add_to_group { + add_to_group.0 as i64 + } else { + sqlx::query("INSERT INTO groups DEFAULT VALUES") + .execute(&self.db_conn) + .await + .unwrap() + .last_insert_rowid() + }; + let expr_id = sqlx::query( + "INSERT INTO group_exprs(group_id, tag, children, predicates, is_logical) VALUES (?, ?, ?, ?, ?)", + ) + .bind(group_id) + .bind(&tag) + .bind(serde_json::to_value(&children_groups).unwrap()) + .bind(serde_json::to_value(&predicates).unwrap()) + .bind(rel_node.typ.is_logical()) + .execute(&self.db_conn) + .await + .unwrap() + .last_insert_rowid(); + (GroupId(group_id as usize), ExprId(expr_id as usize)) + } + } + + async fn reduce_group(&self, group_id: GroupId) -> GroupId { + let row = sqlx::query("SELECT to_group_id FROM group_merges WHERE from_group_id = ?") + .bind(group_id.0 as i64) + .fetch_optional(&self.db_conn) + .await + .unwrap(); + if let Some(row) = row { + let to_group_id = row.get::(0); + GroupId(to_group_id as usize) + } else { + group_id + } + } + + #[async_recursion] + async fn merge_group_inner(&mut self, from_group: GroupId, to_group: GroupId) { + if 
from_group == to_group { + return; + } + // Add the merge record to the group merge table for resolve group in the future + sqlx::query("INSERT INTO group_merges(from_group_id, to_group_id) VALUES (?, ?)") + .bind(from_group.0 as i64) + .bind(to_group.0 as i64) + .execute(&self.db_conn) + .await + .unwrap(); + // Update the group merge table so that all to_group_id are updated to the new group_id + sqlx::query("UPDATE group_merges SET to_group_id = ? WHERE to_group_id = ?") + .bind(to_group.0 as i64) + .bind(from_group.0 as i64) + .execute(&self.db_conn) + .await + .unwrap(); + // Update the group_exprs table so that all group_id are updated to the new group_id + sqlx::query("UPDATE group_exprs SET group_id = ? WHERE group_id = ?") + .bind(to_group.0 as i64) + .bind(from_group.0 as i64) + .execute(&self.db_conn) + .await + .unwrap(); + // Update the children to have the new group_id (is there any way to do it in a single SQL?) + let res = sqlx::query("SELECT group_expr_id, children FROM group_exprs WHERE ? in (SELECT json_each.value FROM json_each(children))") + .bind(from_group.0 as i64) + .fetch_all(&self.db_conn) + .await + .unwrap(); + for row in res { + let group_expr_id = row.get::("group_expr_id"); + let children = row.get::("children"); + let children: Vec = serde_json::from_value(children).unwrap(); + let children: Vec = children + .into_iter() + .map(|x| if x == from_group.0 { to_group.0 } else { x }) + .collect(); + sqlx::query("UPDATE group_exprs SET children = ? 
WHERE group_expr_id = ?") + .bind(serde_json::to_value(&children).unwrap()) + .bind(group_expr_id) + .execute(&self.db_conn) + .await + .unwrap(); + } + // Find duplicate expressions + let res = sqlx::query("SELECT tag, children, predicates, count(group_expr_id) c FROM group_exprs GROUP BY tag, children, predicates HAVING c > 1") + .bind(from_group.0 as i64) + .fetch_all(&self.db_conn) + .await.unwrap(); + let mut pending_cascades_merging = Vec::new(); + for row in res { + let tag = row.get::("tag"); + let children = row.get::("children"); + let predicates = row.get::("predicates"); + // Find the current group ID of the expression + let group_ids = sqlx::query("SELECT group_id FROM group_exprs WHERE tag = ? AND children = ? AND predicates = ?") + .bind(&tag) + .bind(&children) + .bind(&predicates) + .fetch_all(&self.db_conn) + .await + .unwrap(); + assert!(group_ids.len() > 1); + let first_group_id = group_ids[0].get::(0); + for groups in group_ids.into_iter().skip(1) { + pending_cascades_merging.push(( + GroupId(first_group_id as usize), + GroupId(groups.get::(0) as usize), + )); + } + } + for (from_group, to_group) in pending_cascades_merging { + // We need to reduce because each merge would probably invalidate some groups in the + // last loop iteration. 
+ let from_group = self.reduce_group(from_group).await; + let to_group = self.reduce_group(to_group).await; + self.merge_group_inner(from_group, to_group).await; + } + } + + async fn dump(&self) { + let groups = sqlx::query("SELECT group_id FROM groups") + .fetch_all(&self.db_conn) + .await + .unwrap(); + for group in groups { + let group_id = group.get::(0); + let exprs = sqlx::query("SELECT group_expr_id, tag, children, predicates FROM group_exprs WHERE group_id = ?") + .bind(group_id) + .fetch_all(&self.db_conn) + .await + .unwrap(); + println!("Group {}", group_id); + for expr in exprs { + let expr_id = expr.get::(0); + let tag = expr.get::(1); + let children = expr.get::(2); + let predicates = expr.get::(3); + println!(" Expr {} {} {} {}", expr_id, tag, children, predicates); + } + } + } +} + +#[async_trait] +impl Memo for PersistentMemo { + async fn add_new_expr(&mut self, rel_node: ArcPlanNode) -> (GroupId, ExprId) { + self.add_expr_to_group_inner(rel_node, None).await + } + + async fn add_expr_to_group( + &mut self, + rel_node: PlanNodeOrGroup, + group_id: GroupId, + ) -> Option { + match rel_node { + PlanNodeOrGroup::Group(from_group) => { + self.merge_group_inner(from_group, group_id).await; + None + } + PlanNodeOrGroup::PlanNode(rel_node) => { + let (_, expr_id) = self.add_expr_to_group_inner(rel_node, Some(group_id)).await; + Some(expr_id) + } + } + } + + async fn add_new_pred(&mut self, pred_node: ArcPredNode) -> PredId { + let data = T::serialize_pred(&pred_node); + let pred_id_if_exists = sqlx::query("SELECT predicate_id FROM predicates WHERE data = ?") + .bind(&data) + .fetch_optional(&self.db_conn) + .await + .unwrap(); + if let Some(pred_id) = pred_id_if_exists { + return PredId(pred_id.get::(0) as usize); + } + let pred_id = sqlx::query("INSERT INTO predicates(data) VALUES (?)") + .bind(&data) + .execute(&self.db_conn) + .await + .unwrap() + .last_insert_rowid(); + PredId(pred_id as usize) + } + + async fn get_pred(&self, pred_id: PredId) -> 
ArcPredNode { + let pred_data = sqlx::query("SELECT data FROM predicates WHERE predicate_id = ?") + .bind(pred_id.0 as i64) + .fetch_one(&self.db_conn) + .await + .unwrap() + .get::(0); + T::deserialize_pred(pred_data) + } + + async fn get_group_id(&self, expr_id: ExprId) -> GroupId { + let group_id = sqlx::query("SELECT group_id FROM group_exprs WHERE group_expr_id = ?") + .bind(expr_id.0 as i64) + .fetch_one(&self.db_conn) + .await + .unwrap() + .get::(0); + GroupId(group_id as usize) + } + + async fn get_expr_memoed(&self, expr_id: ExprId) -> ArcMemoPlanNode { + let row = sqlx::query( + "SELECT tag, children, predicates FROM group_exprs WHERE group_expr_id = ?", + ) + .bind(expr_id.0 as i64) + .fetch_one(&self.db_conn) + .await + .unwrap(); + let tag = row.get::(0); + let children = row.get::(1); + let predicates = row.get::(2); + let children: Vec = serde_json::from_value(children).unwrap(); + let children = children.into_iter().map(|x| GroupId(x)).collect(); + let predicates: Vec = serde_json::from_value(predicates).unwrap(); + let predicates = predicates.into_iter().map(|x| PredId(x)).collect(); + MemoPlanNode { + typ: T::deserialize_plan_tag(serde_json::from_str(&tag).unwrap()), + children, + predicates, + } + .into() + } + + async fn get_all_group_ids(&self) -> Vec { + let group_ids = sqlx::query("SELECT group_id FROM groups ORDER BY group_id") + .fetch_all(&self.db_conn) + .await + .unwrap(); + let group_ids: Vec = group_ids + .into_iter() + .map(|row| GroupId(row.get::(0) as usize)) + .collect(); + group_ids + } + + async fn get_all_exprs_in_group(&self, group_id: GroupId) -> Vec { + let expr_ids = sqlx::query( + "SELECT group_expr_id FROM group_exprs WHERE group_id = ? 
ORDER BY group_expr_id", + ) + .bind(group_id.0 as i64) + .fetch_all(&self.db_conn) + .await + .unwrap(); + let expr_ids: Vec = expr_ids + .into_iter() + .map(|row| ExprId(row.get::(0) as usize)) + .collect(); + expr_ids + } + + async fn estimated_plan_space(&self) -> usize { + unimplemented!() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + nodes::Value, + tests::common::{expr, group, join, list, project, scan, MemoTestRelTyp}, + }; + + async fn create_db_and_migrate() -> PersistentMemo { + let mut memo = new_in_memory::().await.unwrap(); + memo.setup().await.unwrap(); + memo + } + + #[tokio::test] + async fn setup_in_memory() { + create_db_and_migrate().await; + } + + #[tokio::test] + async fn add_predicate() { + let mut memo = create_db_and_migrate().await; + let pred_node = list(vec![expr(Value::Int32(233))]); + let p1 = memo.add_new_pred(pred_node.clone()).await; + let p2 = memo.add_new_pred(pred_node.clone()).await; + assert_eq!(p1, p2); + } + + #[tokio::test] + async fn add_expr() { + let mut memo = create_db_and_migrate().await; + let scan_node = scan("t1"); + let p1 = memo.add_new_expr(scan_node.clone()).await; + let p2 = memo.add_new_expr(scan_node.clone()).await; + assert_eq!(p1, p2); + } + + #[tokio::test] + async fn group_merge_1() { + let mut memo = create_db_and_migrate().await; + let (group_id, _) = memo + .add_new_expr(join(scan("t1"), scan("t2"), expr(Value::Bool(true)))) + .await; + memo.add_expr_to_group( + join(scan("t2"), scan("t1"), expr(Value::Bool(true))).into(), + group_id, + ) + .await; + assert_eq!(memo.get_all_exprs_in_group(group_id).await.len(), 2); + } + + #[tokio::test] + async fn group_merge_2() { + let mut memo = create_db_and_migrate().await; + let (group_id_1, _) = memo + .add_new_expr(project( + join(scan("t1"), scan("t2"), expr(Value::Bool(true))), + list(vec![expr(Value::Int64(1))]), + )) + .await; + let (group_id_2, _) = memo + .add_new_expr(project( + join(scan("t1"), scan("t2"), 
expr(Value::Bool(true))), + list(vec![expr(Value::Int64(1))]), + )) + .await; + assert_eq!(group_id_1, group_id_2); + } + + #[tokio::test] + async fn group_merge_3() { + let mut memo = create_db_and_migrate().await; + let expr1 = project(scan("t1"), list(vec![expr(Value::Int64(1))])); + let expr2 = project(scan("t1-alias"), list(vec![expr(Value::Int64(1))])); + memo.add_new_expr(expr1.clone()).await; + memo.add_new_expr(expr2.clone()).await; + // merging two child groups causes parent to merge + let (group_id_expr, _) = memo.get_expr_info(scan("t1")).await; + memo.add_expr_to_group(scan("t1-alias").into(), group_id_expr) + .await; + let (group_1, _) = memo.get_expr_info(expr1).await; + let (group_2, _) = memo.get_expr_info(expr2).await; + assert_eq!(group_1, group_2); + } + + #[tokio::test] + async fn group_merge_4() { + let mut memo = create_db_and_migrate().await; + let expr1 = project( + project(scan("t1"), list(vec![expr(Value::Int64(1))])), + list(vec![expr(Value::Int64(2))]), + ); + let expr2 = project( + project(scan("t1-alias"), list(vec![expr(Value::Int64(1))])), + list(vec![expr(Value::Int64(2))]), + ); + memo.add_new_expr(expr1.clone()).await; + memo.add_new_expr(expr2.clone()).await; + // merge two child groups, cascading merge + let (group_id_expr, _) = memo.get_expr_info(scan("t1")).await; + memo.add_expr_to_group(scan("t1-alias").into(), group_id_expr) + .await; + let (group_1, _) = memo.get_expr_info(expr1.clone()).await; + let (group_2, _) = memo.get_expr_info(expr2.clone()).await; + assert_eq!(group_1, group_2); + let (group_1, _) = memo.get_expr_info(expr1.child_rel(0)).await; + let (group_2, _) = memo.get_expr_info(expr2.child_rel(0)).await; + assert_eq!(group_1, group_2); + } + + #[tokio::test] + async fn group_merge_5() { + let mut memo = create_db_and_migrate().await; + let expr1 = project( + project(scan("t1"), list(vec![expr(Value::Int64(1))])), + list(vec![expr(Value::Int64(2))]), + ); + let expr2 = project( + project(scan("t1-alias"), 
list(vec![expr(Value::Int64(1))])), + list(vec![expr(Value::Int64(2))]), + ); + let (_, expr1_id) = memo.add_new_expr(expr1.clone()).await; + let (_, expr2_id) = memo.add_new_expr(expr2.clone()).await; + + // experimenting with group id in expr (i.e., when apply rules) + let (scan_t1, _) = memo.get_expr_info(scan("t1")).await; + let pred = list(vec![expr(Value::Int64(1))]); + let proj_binding = project(group(scan_t1), pred); + let middle_proj_2 = memo.get_expr_memoed(expr2_id).await.children[0]; + + memo.add_expr_to_group(proj_binding.into(), middle_proj_2) + .await; + assert_eq!( + memo.get_expr_memoed(expr1_id).await, + memo.get_expr_memoed(expr2_id).await + ); // these two expressions are merged + assert_eq!( + memo.get_expr_info(expr1).await, + memo.get_expr_info(expr2).await + ); + } +} diff --git a/optd-ng-kernel/src/lib.rs b/optd-ng-kernel/src/lib.rs new file mode 100644 index 00000000..928c2895 --- /dev/null +++ b/optd-ng-kernel/src/lib.rs @@ -0,0 +1,10 @@ +// Copyright (c) 2023-2024 CMU Database Group +// +// Use of this source code is governed by an MIT-style license that can be found in the LICENSE file or at +// https://opensource.org/licenses/MIT. + +pub mod cascades; +pub mod nodes; + +#[cfg(test)] +mod tests; diff --git a/optd-ng-kernel/src/nodes.rs b/optd-ng-kernel/src/nodes.rs new file mode 100644 index 00000000..1a30b5e4 --- /dev/null +++ b/optd-ng-kernel/src/nodes.rs @@ -0,0 +1,370 @@ +// Copyright (c) 2023-2024 CMU Database Group +// +// Use of this source code is governed by an MIT-style license that can be found in the LICENSE file or at +// https://opensource.org/licenses/MIT. + +//! The RelNode is the basic data structure of the optimizer. It is dynamically typed and is +//! the internal representation of the plan nodes. 
+ +use std::fmt::{Debug, Display}; +use std::hash::Hash; +use std::sync::Arc; + +use arrow_schema::DataType; +use chrono::NaiveDate; +use ordered_float::OrderedFloat; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; + +use crate::cascades::optimizer::GroupId; + +#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub struct SerializableOrderedF64(pub OrderedFloat); + +impl Serialize for SerializableOrderedF64 { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + // Directly serialize the inner f64 value of the OrderedFloat + self.0 .0.serialize(serializer) + } +} + +impl<'de> Deserialize<'de> for SerializableOrderedF64 { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + // Deserialize an f64 and wrap it in an OrderedFloat + let float = f64::deserialize(deserializer)?; + Ok(SerializableOrderedF64(OrderedFloat(float))) + } +} + +// TODO: why not use arrow types here? Do we really need to define our own Value type? +// Shouldn't we at least remove this from the core/engine? 
+#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, PartialOrd, Ord)] +pub enum Value { + UInt8(u8), + UInt16(u16), + UInt32(u32), + UInt64(u64), + Int8(i8), + Int16(i16), + Int32(i32), + Int64(i64), + Int128(i128), + Float(SerializableOrderedF64), + String(Arc), + Bool(bool), + Date32(i32), + Decimal128(i128), + Serialized(Arc<[u8]>), +} + +impl std::fmt::Display for Value { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::UInt8(x) => write!(f, "{x}(u8)"), + Self::UInt16(x) => write!(f, "{x}(u16)"), + Self::UInt32(x) => write!(f, "{x}(u32)"), + Self::UInt64(x) => write!(f, "{x}(u64)"), + Self::Int8(x) => write!(f, "{x}(i8)"), + Self::Int16(x) => write!(f, "{x}(i16)"), + Self::Int32(x) => write!(f, "{x}(i32)"), + Self::Int64(x) => write!(f, "{x}(i64)"), + Self::Int128(x) => write!(f, "{x}(i128)"), + Self::Float(x) => write!(f, "{}(float)", x.0), + Self::String(x) => write!(f, "\"{x}\""), + Self::Bool(x) => write!(f, "{x}"), + Self::Date32(x) => write!(f, "{x}(date32)"), + Self::Decimal128(x) => write!(f, "{x}(decimal128)"), + Self::Serialized(x) => write!(f, "", x.len()), + } + } +} + +/// The `as_*()` functions do not perform conversions. This is *unlike* the `as` +/// keyword in rust. +/// +/// If you want to perform conversions, use the `to_*()` functions. 
+impl Value { + pub fn as_u8(&self) -> u8 { + match self { + Value::UInt8(i) => *i, + _ => panic!("Value is not an u8"), + } + } + + pub fn as_u16(&self) -> u16 { + match self { + Value::UInt16(i) => *i, + _ => panic!("Value is not an u16"), + } + } + + pub fn as_u32(&self) -> u32 { + match self { + Value::UInt32(i) => *i, + _ => panic!("Value is not an u32"), + } + } + + pub fn as_u64(&self) -> u64 { + match self { + Value::UInt64(i) => *i, + _ => panic!("Value is not an u64"), + } + } + + pub fn as_i8(&self) -> i8 { + match self { + Value::Int8(i) => *i, + _ => panic!("Value is not an i8"), + } + } + + pub fn as_i16(&self) -> i16 { + match self { + Value::Int16(i) => *i, + _ => panic!("Value is not an i16"), + } + } + + pub fn as_i32(&self) -> i32 { + match self { + Value::Int32(i) => *i, + _ => panic!("Value is not an i32"), + } + } + + pub fn as_i64(&self) -> i64 { + match self { + Value::Int64(i) => *i, + _ => panic!("Value is not an i64"), + } + } + + pub fn as_i128(&self) -> i128 { + match self { + Value::Int128(i) => *i, + _ => panic!("Value is not an i128"), + } + } + + pub fn as_f64(&self) -> f64 { + match self { + Value::Float(i) => *i.0, + _ => panic!("Value is not an f64"), + } + } + + pub fn as_bool(&self) -> bool { + match self { + Value::Bool(i) => *i, + _ => panic!("Value is not a bool"), + } + } + + pub fn as_str(&self) -> Arc { + match self { + Value::String(i) => i.clone(), + _ => panic!("Value is not a string"), + } + } + + pub fn as_slice(&self) -> Arc<[u8]> { + match self { + Value::Serialized(i) => i.clone(), + _ => panic!("Value is not a serialized"), + } + } + + pub fn convert_to_type(&self, typ: DataType) -> Value { + match typ { + DataType::Int32 => Value::Int32(match self { + Value::Int32(i32) => *i32, + Value::Int64(i64) => (*i64).try_into().unwrap(), + _ => panic!("{self} could not be converted into an Int32"), + }), + DataType::Int64 => Value::Int64(match self { + Value::Int64(i64) => *i64, + Value::Int32(i32) => (*i32).into(), + _ 
=> panic!("{self} could not be converted into an Int64"), + }), + DataType::UInt64 => Value::UInt64(match self { + Value::Int64(i64) => (*i64).try_into().unwrap(), + Value::UInt64(i64) => *i64, + Value::UInt32(i32) => (*i32).into(), + _ => panic!("{self} could not be converted into an UInt64"), + }), + DataType::Date32 => Value::Date32(match self { + Value::Date32(date32) => *date32, + Value::String(str) => { + let date = NaiveDate::parse_from_str(str, "%Y-%m-%d").unwrap(); + let epoch = NaiveDate::from_ymd_opt(1970, 1, 1).unwrap(); + let duration_since_epoch = date.signed_duration_since(epoch); + let days_since_epoch: i32 = duration_since_epoch.num_days() as i32; + days_since_epoch + } + _ => panic!("{self} could not be converted into an Date32"), + }), + _ => unimplemented!("Have not implemented convert_to_type for DataType {typ}"), + } + } +} + +pub trait NodeType: + PartialEq + Eq + Hash + Clone + 'static + Display + Debug + Send + Sync +{ + type PredType: PartialEq + Eq + Hash + Clone + 'static + Display + Debug + Send + Sync; + + fn is_logical(&self) -> bool; +} + +pub trait PersistentNodeType: NodeType { + fn serialize_pred(pred: &ArcPredNode) -> serde_json::Value; + + fn deserialize_pred(data: serde_json::Value) -> ArcPredNode; + + fn serialize_plan_tag(tag: Self) -> serde_json::Value; + + fn deserialize_plan_tag(data: serde_json::Value) -> Self; +} + +/// A pointer to a plan node +pub type ArcPlanNode = Arc>; + +/// A pointer to a predicate node +pub type ArcPredNode = Arc>; + +#[derive(Clone, Debug, Hash, PartialEq, Eq)] +pub enum PlanNodeOrGroup { + PlanNode(ArcPlanNode), + Group(GroupId), +} + +impl PlanNodeOrGroup { + pub fn is_materialized(&self) -> bool { + match self { + PlanNodeOrGroup::PlanNode(_) => true, + PlanNodeOrGroup::Group(_) => false, + } + } + + pub fn unwrap_typ(&self) -> T { + self.unwrap_plan_node().typ.clone() + } + + pub fn unwrap_plan_node(&self) -> ArcPlanNode { + match self { + PlanNodeOrGroup::PlanNode(node) => node.clone(), + 
PlanNodeOrGroup::Group(_) => panic!("Expected PlanNode, found Group"), + } + } + + pub fn unwrap_group(&self) -> GroupId { + match self { + PlanNodeOrGroup::PlanNode(_) => panic!("Expected Group, found PlanNode"), + PlanNodeOrGroup::Group(group_id) => *group_id, + } + } +} + +impl std::fmt::Display for PlanNodeOrGroup { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + PlanNodeOrGroup::PlanNode(node) => write!(f, "{}", node), + PlanNodeOrGroup::Group(group_id) => write!(f, "{}", group_id), + } + } +} + +#[derive(Clone, Debug, Hash, PartialEq, Eq)] +pub struct PlanNode { + /// A generic plan node type + pub typ: T, + /// Child plan nodes, which may be materialized or placeholder group IDs + /// based on how this node was initialized + pub children: Vec>, + /// Predicate nodes, which are always materialized + pub predicates: Vec>, +} + +impl std::fmt::Display for PlanNode { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "({}", self.typ)?; + for child in &self.children { + write!(f, " {}", child)?; + } + for pred in &self.predicates { + write!(f, " {}", pred)?; + } + write!(f, ")") + } +} + +impl PlanNode { + pub fn child(&self, idx: usize) -> PlanNodeOrGroup { + self.children[idx].clone() + } + + pub fn child_rel(&self, idx: usize) -> ArcPlanNode { + self.child(idx).unwrap_plan_node() + } + + pub fn predicate(&self, idx: usize) -> ArcPredNode { + self.predicates[idx].clone() + } +} + +impl From> for PlanNodeOrGroup { + fn from(value: PlanNode) -> Self { + Self::PlanNode(value.into()) + } +} + +impl From> for PlanNodeOrGroup { + fn from(value: ArcPlanNode) -> Self { + Self::PlanNode(value) + } +} + +impl From for PlanNodeOrGroup { + fn from(value: GroupId) -> Self { + Self::Group(value) + } +} + +#[derive(Clone, Debug, Hash, PartialEq, Eq)] +pub struct PredNode { + /// A generic predicate node type + pub typ: T::PredType, + /// Child predicate nodes, always materialized + pub children: Vec>, + 
/// Data associated with the predicate, if any + pub data: Option, +} + +impl std::fmt::Display for PredNode { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "({}", self.typ)?; + for child in &self.children { + write!(f, " {}", child)?; + } + if let Some(data) = &self.data { + write!(f, " {}", data)?; + } + write!(f, ")") + } +} + +impl PredNode { + pub fn child(&self, idx: usize) -> ArcPredNode { + self.children[idx].clone() + } + + pub fn unwrap_data(&self) -> Value { + self.data.clone().unwrap() + } +} diff --git a/optd-ng-kernel/src/tests.rs b/optd-ng-kernel/src/tests.rs new file mode 100644 index 00000000..34994bf5 --- /dev/null +++ b/optd-ng-kernel/src/tests.rs @@ -0,0 +1 @@ +pub mod common; diff --git a/optd-ng-kernel/src/tests/common.rs b/optd-ng-kernel/src/tests/common.rs new file mode 100644 index 00000000..0a9f7fd1 --- /dev/null +++ b/optd-ng-kernel/src/tests/common.rs @@ -0,0 +1,292 @@ +// Copyright (c) 2023-2024 CMU Database Group +// +// Use of this source code is governed by an MIT-style license that can be found in the LICENSE file or at +// https://opensource.org/licenses/MIT. 
+
+use std::sync::Arc;
+
+use serde::{Deserialize, Serialize};
+
+use crate::{
+    cascades::optimizer::GroupId,
+    nodes::{
+        ArcPlanNode, ArcPredNode, NodeType, PersistentNodeType, PlanNode, PlanNodeOrGroup,
+        PredNode, Value,
+    },
+};
+
+/// Plan-node tags used by the memo tests; `Physical*` variants are the
+/// non-logical counterparts.
+#[allow(dead_code)]
+#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
+pub(crate) enum MemoTestRelTyp {
+    Join,
+    Project,
+    Scan,
+    Sort,
+    Filter,
+    Agg,
+    PhysicalNestedLoopJoin,
+    PhysicalProject,
+    PhysicalFilter,
+    PhysicalScan,
+    PhysicalSort,
+    PhysicalPartition,
+    PhysicalStreamingAgg,
+    PhysicalHashAgg,
+}
+
+/// Predicate-node tags paired with `MemoTestRelTyp`.
+#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
+pub(crate) enum MemoTestPredTyp {
+    List,
+    Expr,
+    TableName,
+    ColumnRef,
+}
+
+impl std::fmt::Display for MemoTestRelTyp {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        // Display simply reuses the derived Debug name of the variant.
+        write!(f, "{:?}", self)
+    }
+}
+
+impl std::fmt::Display for MemoTestPredTyp {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "{:?}", self)
+    }
+}
+
+impl NodeType for MemoTestRelTyp {
+    type PredType = MemoTestPredTyp;
+
+    fn is_logical(&self) -> bool {
+        matches!(
+            self,
+            Self::Project | Self::Scan | Self::Join | Self::Sort | Self::Filter
+        )
+    }
+}
+
+// TODO: move this into nodes.rs?
+#[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)] +pub struct PersistentPredNode { + /// A generic predicate node type + pub typ: T::PredType, + /// Child predicate nodes, always materialized + pub children: Vec>, + /// Data associated with the predicate, if any + pub data: Option, +} + +impl From> for PredNode { + fn from(node: PersistentPredNode) -> Self { + PredNode { + typ: node.typ, + children: node + .children + .into_iter() + .map(|x| Arc::new(x.into())) + .collect(), + data: node.data, + } + } +} + +impl From> for PersistentPredNode { + fn from(node: PredNode) -> Self { + PersistentPredNode { + typ: node.typ, + children: node + .children + .into_iter() + .map(|x| x.as_ref().clone().into()) + .collect(), + data: node.data, + } + } +} + +impl PersistentNodeType for MemoTestRelTyp { + fn serialize_pred(pred: &ArcPredNode) -> serde_json::Value { + let node: PersistentPredNode = pred.as_ref().clone().into(); + serde_json::to_value(node).unwrap() + } + + fn deserialize_pred(data: serde_json::Value) -> ArcPredNode { + let node: PersistentPredNode = serde_json::from_value(data).unwrap(); + Arc::new(node.into()) + } + + fn serialize_plan_tag(tag: Self) -> serde_json::Value { + serde_json::to_value(tag).unwrap() + } + + fn deserialize_plan_tag(data: serde_json::Value) -> Self { + serde_json::from_value(data).unwrap() + } +} + +pub(crate) fn join( + left: impl Into>, + right: impl Into>, + cond: ArcPredNode, +) -> ArcPlanNode { + Arc::new(PlanNode { + typ: MemoTestRelTyp::Join, + children: vec![left.into(), right.into()], + predicates: vec![cond], + }) +} + +#[allow(dead_code)] +pub(crate) fn agg( + input: impl Into>, + group_bys: ArcPredNode, +) -> ArcPlanNode { + Arc::new(PlanNode { + typ: MemoTestRelTyp::Agg, + children: vec![input.into()], + predicates: vec![group_bys], + }) +} + +pub(crate) fn scan(table: &str) -> ArcPlanNode { + Arc::new(PlanNode { + typ: MemoTestRelTyp::Scan, + children: vec![], + predicates: vec![table_name(table)], + }) +} 
+
+/// Leaf predicate carrying a table name as string data.
+pub(crate) fn table_name(table: &str) -> ArcPredNode<MemoTestRelTyp> {
+    Arc::new(PredNode {
+        typ: MemoTestPredTyp::TableName,
+        children: vec![],
+        data: Some(Value::String(table.to_string().into())),
+    })
+}
+
+pub(crate) fn project(
+    input: impl Into<PlanNodeOrGroup<MemoTestRelTyp>>,
+    expr_list: ArcPredNode<MemoTestRelTyp>,
+) -> ArcPlanNode<MemoTestRelTyp> {
+    Arc::new(PlanNode {
+        typ: MemoTestRelTyp::Project,
+        children: vec![input.into()],
+        predicates: vec![expr_list],
+    })
+}
+
+pub(crate) fn physical_nested_loop_join(
+    left: impl Into<PlanNodeOrGroup<MemoTestRelTyp>>,
+    right: impl Into<PlanNodeOrGroup<MemoTestRelTyp>>,
+    cond: ArcPredNode<MemoTestRelTyp>,
+) -> ArcPlanNode<MemoTestRelTyp> {
+    Arc::new(PlanNode {
+        typ: MemoTestRelTyp::PhysicalNestedLoopJoin,
+        children: vec![left.into(), right.into()],
+        predicates: vec![cond],
+    })
+}
+
+#[allow(dead_code)]
+pub(crate) fn physical_project(
+    input: impl Into<PlanNodeOrGroup<MemoTestRelTyp>>,
+    expr_list: ArcPredNode<MemoTestRelTyp>,
+) -> ArcPlanNode<MemoTestRelTyp> {
+    Arc::new(PlanNode {
+        typ: MemoTestRelTyp::PhysicalProject,
+        children: vec![input.into()],
+        predicates: vec![expr_list],
+    })
+}
+
+pub(crate) fn physical_filter(
+    input: impl Into<PlanNodeOrGroup<MemoTestRelTyp>>,
+    cond: ArcPredNode<MemoTestRelTyp>,
+) -> ArcPlanNode<MemoTestRelTyp> {
+    Arc::new(PlanNode {
+        typ: MemoTestRelTyp::PhysicalFilter,
+        children: vec![input.into()],
+        predicates: vec![cond],
+    })
+}
+
+pub(crate) fn physical_scan(table: &str) -> ArcPlanNode<MemoTestRelTyp> {
+    Arc::new(PlanNode {
+        typ: MemoTestRelTyp::PhysicalScan,
+        children: vec![],
+        predicates: vec![table_name(table)],
+    })
+}
+
+pub(crate) fn physical_sort(
+    input: impl Into<PlanNodeOrGroup<MemoTestRelTyp>>,
+    sort_expr: ArcPredNode<MemoTestRelTyp>,
+) -> ArcPlanNode<MemoTestRelTyp> {
+    Arc::new(PlanNode {
+        typ: MemoTestRelTyp::PhysicalSort,
+        children: vec![input.into()],
+        predicates: vec![sort_expr],
+    })
+}
+
+#[allow(dead_code)]
+pub(crate) fn physical_partition(
+    input: impl Into<PlanNodeOrGroup<MemoTestRelTyp>>,
+    partition_expr: ArcPredNode<MemoTestRelTyp>,
+) -> ArcPlanNode<MemoTestRelTyp> {
+    Arc::new(PlanNode {
+        typ: MemoTestRelTyp::PhysicalPartition,
+        children: vec![input.into()],
+        predicates: vec![partition_expr],
+    })
+}
+
+pub(crate) fn physical_streaming_agg(
+    input: impl Into<PlanNodeOrGroup<MemoTestRelTyp>>,
+    group_bys: ArcPredNode<MemoTestRelTyp>,
+) -> ArcPlanNode<MemoTestRelTyp> {
+    Arc::new(PlanNode {
+        typ: MemoTestRelTyp::PhysicalStreamingAgg,
+        children: vec![input.into()],
+        predicates: vec![group_bys],
+    })
+}
+
+pub(crate) fn physical_hash_agg(
+    input: impl Into<PlanNodeOrGroup<MemoTestRelTyp>>,
+    group_bys: ArcPredNode<MemoTestRelTyp>,
+) -> ArcPlanNode<MemoTestRelTyp> {
+    Arc::new(PlanNode {
+        typ: MemoTestRelTyp::PhysicalHashAgg,
+        children: vec![input.into()],
+        predicates: vec![group_bys],
+    })
+}
+
+/// Predicate list node grouping several child predicates; carries no data.
+pub(crate) fn list(items: Vec<ArcPredNode<MemoTestRelTyp>>) -> ArcPredNode<MemoTestRelTyp> {
+    Arc::new(PredNode {
+        typ: MemoTestPredTyp::List,
+        children: items,
+        data: None,
+    })
+}
+
+pub(crate) fn expr(data: Value) -> ArcPredNode<MemoTestRelTyp> {
+    Arc::new(PredNode {
+        typ: MemoTestPredTyp::Expr,
+        children: vec![],
+        data: Some(data),
+    })
+}
+
+pub(crate) fn column_ref(col: &str) -> ArcPredNode<MemoTestRelTyp> {
+    Arc::new(PredNode {
+        typ: MemoTestPredTyp::ColumnRef,
+        children: vec![],
+        data: Some(Value::String(col.to_string().into())),
+    })
+}
+
+/// Wrap a memo group ID as an unmaterialized plan-node placeholder.
+pub(crate) fn group(group_id: GroupId) -> PlanNodeOrGroup<MemoTestRelTyp> {
+    PlanNodeOrGroup::Group(group_id)
+}