diff --git a/clickhouse/columns/factory.cpp b/clickhouse/columns/factory.cpp index fbd57889..e003b7f5 100644 --- a/clickhouse/columns/factory.cpp +++ b/clickhouse/columns/factory.cpp @@ -23,10 +23,26 @@ #include "../exceptions.h" #include +#include namespace clickhouse { namespace { +// Like Python's list's []: +// * 0 - first element +// * 1 - second element +// * -1 - last element +// * -2 - one before last, etc. Throws ValidationError when position is outside [-size, size). +const auto& GetASTChildElement(const TypeAst & ast, int position) { + // Resolve the index in a wide signed type: negative positions count back from the end, so -size is the first element. + const long long count = static_cast<long long>(ast.elements.size()); + const long long index = position < 0 ? count + position : position; + if (index < 0 || index >= count) + throw ValidationError("AST child element index out of bounds: " + std::to_string(position)); + + return ast.elements[static_cast<size_t>(index)]; +} + static ColumnRef CreateTerminalColumn(const TypeAst& ast) { switch (ast.code) { case Type::Void: @@ -58,24 +74,24 @@ static ColumnRef CreateTerminalColumn(const TypeAst& ast) { return std::make_shared(); case Type::Decimal: - return std::make_shared(ast.elements.front().value, ast.elements.back().value); + return std::make_shared(GetASTChildElement(ast, 0).value, GetASTChildElement(ast, -1).value); case Type::Decimal32: - return std::make_shared(9, ast.elements.front().value); + return std::make_shared(9, GetASTChildElement(ast, 0).value); case Type::Decimal64: - return std::make_shared(18, ast.elements.front().value); + return std::make_shared(18, GetASTChildElement(ast, 0).value); case Type::Decimal128: - return std::make_shared(38, ast.elements.front().value); + return std::make_shared(38, GetASTChildElement(ast, 0).value); case Type::String: return std::make_shared(); case Type::FixedString: - return std::make_shared(ast.elements.front().value); + return std::make_shared(GetASTChildElement(ast, 0).value); case Type::DateTime: if (ast.elements.empty()) { return std::make_shared(); } else { - return std::make_shared(ast.elements[0].value_string); + return std::make_shared(GetASTChildElement(ast, 0).value_string); } 
case Type::DateTime64: if (ast.elements.empty()) { @@ -120,13 +136,13 @@ static ColumnRef CreateColumnFromAst(const TypeAst& ast, CreateColumnByTypeSetti switch (ast.meta) { case TypeAst::Array: { return std::make_shared( - CreateColumnFromAst(ast.elements.front(), settings) + CreateColumnFromAst(GetASTChildElement(ast, 0), settings) ); } case TypeAst::Nullable: { return std::make_shared( - CreateColumnFromAst(ast.elements.front(), settings), + CreateColumnFromAst(GetASTChildElement(ast, 0), settings), std::make_shared() ); } @@ -159,9 +175,10 @@ static ColumnRef CreateColumnFromAst(const TypeAst& ast, CreateColumnByTypeSetti enum_items.reserve(ast.elements.size() / 2); for (size_t i = 0; i < ast.elements.size(); i += 2) { - enum_items.push_back( - Type::EnumItem{ ast.elements[i].value_string, - (int16_t)ast.elements[i + 1].value }); + enum_items.push_back(Type::EnumItem{ + ast.elements[i].value_string, + static_cast(ast.elements[i + 1].value) + }); } if (ast.code == Type::Enum8) { @@ -176,14 +193,14 @@ static ColumnRef CreateColumnFromAst(const TypeAst& ast, CreateColumnByTypeSetti break; } case TypeAst::LowCardinality: { - const auto nested = ast.elements.front(); + const auto nested = GetASTChildElement(ast, 0); if (settings.low_cardinality_as_wrapped_column) { switch (nested.code) { // TODO (nemkov): update this to maximize code reuse. 
case Type::String: return std::make_shared>(); case Type::FixedString: - return std::make_shared>(nested.elements.front().value); + return std::make_shared>(GetASTChildElement(nested, 0).value); case Type::Nullable: throw UnimplementedError("LowCardinality(" + nested.name + ") is not supported with LowCardinalityAsWrappedColumn on"); default: @@ -196,11 +213,11 @@ static ColumnRef CreateColumnFromAst(const TypeAst& ast, CreateColumnByTypeSetti case Type::String: return std::make_shared>(); case Type::FixedString: - return std::make_shared>(nested.elements.front().value); + return std::make_shared>(GetASTChildElement(nested, 0).value); case Type::Nullable: return std::make_shared( std::make_shared( - CreateColumnFromAst(nested.elements.front(), settings), + CreateColumnFromAst(GetASTChildElement(nested, 0), settings), std::make_shared() ) ); @@ -210,7 +227,7 @@ static ColumnRef CreateColumnFromAst(const TypeAst& ast, CreateColumnByTypeSetti } } case TypeAst::SimpleAggregateFunction: { - return CreateTerminalColumn(ast.elements.back()); + return CreateTerminalColumn(GetASTChildElement(ast, -1)); } case TypeAst::Map: { diff --git a/clickhouse/types/type_parser.cpp b/clickhouse/types/type_parser.cpp index e16aadb5..1ec7a7b7 100644 --- a/clickhouse/types/type_parser.cpp +++ b/clickhouse/types/type_parser.cpp @@ -1,10 +1,21 @@ #include "type_parser.h" +#include "clickhouse/exceptions.h" +#include "clickhouse/base/platform.h" // for _win_ + #include +#include #include #include #include +#if defined _win_ +#include +#else +#include +#endif + + namespace clickhouse { bool TypeAst::operator==(const TypeAst & other) const { @@ -16,6 +27,7 @@ bool TypeAst::operator==(const TypeAst & other) const { } static const std::unordered_map kTypeCode = { + { "Void", Type::Void }, { "Int8", Type::Int8 }, { "Int16", Type::Int16 }, { "Int32", Type::Int32 }, @@ -41,23 +53,38 @@ static const std::unordered_map kTypeCode = { { "IPv4", Type::IPv4 }, { "IPv6", Type::IPv6 }, { "Int128", 
Type::Int128 }, +// { "UInt128", Type::UInt128 }, { "Decimal", Type::Decimal }, { "Decimal32", Type::Decimal32 }, { "Decimal64", Type::Decimal64 }, { "Decimal128", Type::Decimal128 }, { "LowCardinality", Type::LowCardinality }, - { "Map", Type::Map}, - { "Point", Type::Point}, - { "Ring", Type::Ring}, - { "Polygon", Type::Polygon}, - { "MultiPolygon", Type::MultiPolygon}, + { "Map", Type::Map }, + { "Point", Type::Point }, + { "Ring", Type::Ring }, + { "Polygon", Type::Polygon }, + { "MultiPolygon", Type::MultiPolygon }, }; +template +inline int CompateStringsCaseInsensitive(const L& left, const R& right) { + // Compare the sizes directly: subtracting unsigned sizes wraps around instead of going negative. + if (left.size() != right.size()) + return left.size() > right.size() ? 1 : -1; + +#if defined _win_ + return _strnicmp(left.data(), right.data(), left.size()); +#else + return strncasecmp(left.data(), right.data(), left.size()); +#endif +} + static Type::Code GetTypeCode(const std::string& name) { auto it = kTypeCode.find(name); if (it != kTypeCode.end()) { return it->second; } + return Type::Void; } @@ -97,6 +124,17 @@ static TypeAst::Meta GetTypeMeta(const StringView& name) { return TypeAst::Terminal; } +bool ValidateAST(const TypeAst& ast) { + // An unknown type name ends up as a Void terminal whose name is not actually "void" (GetTypeCode falls back to Type::Void); reject it. + if (ast.meta == TypeAst::Terminal + && ast.code == Type::Void + && CompateStringsCaseInsensitive(ast.name, std::string_view("void")) != 0) + //throw UnimplementedError("Unsupported type: " + ast.name); + return false; + + return true; +} + TypeParser::TypeParser(const StringView& name) : cur_(name.data()) @@ -111,6 +149,7 @@ bool TypeParser::Parse(TypeAst* type) { type_ = type; open_elements_.push(type_); + size_t processed_tokens = 0; do { const Token & token = NextToken(); switch (token.type) { @@ -159,11 +198,17 @@ bool TypeParser::Parse(TypeAst* type) { // Ubalanced braces, brackets, etc is an error. 
if (open_elements_.size() != 1) return false; - return true; + + // Empty input string, no tokens produced + if (processed_tokens == 0) + return false; + + return ValidateAST(*type); } case Token::Invalid: return false; } + ++processed_tokens; } while (true); } diff --git a/ut/CreateColumnByType_ut.cpp b/ut/CreateColumnByType_ut.cpp index fb7ffd85..fecf0ea3 100644 --- a/ut/CreateColumnByType_ut.cpp +++ b/ut/CreateColumnByType_ut.cpp @@ -50,6 +50,12 @@ TEST(CreateColumnByType, DateTime) { ASSERT_EQ(CreateColumnByType("DateTime64(3, 'UTC')")->As()->Timezone(), "UTC"); } +TEST(CreateColumnByType, AggregateFunction) { + EXPECT_EQ(nullptr, CreateColumnByType("AggregateFunction(argMax, Int32, DateTime64(3))")); + EXPECT_EQ(nullptr, CreateColumnByType("AggregateFunction(argMax, FIxedString(10), DateTime64(3, 'UTC'))")); +} + + class CreateColumnByTypeWithName : public ::testing::TestWithParam {}; diff --git a/ut/client_ut.cpp b/ut/client_ut.cpp index 62641c19..f86f3ff7 100644 --- a/ut/client_ut.cpp +++ b/ut/client_ut.cpp @@ -1313,6 +1313,25 @@ TEST_P(ClientCase, OnProfileEvents) { } } +TEST_P(ClientCase, SelectAggregateFunction) { + // Verifies that performing a SELECT of a value of type AggregateFunction(...) doesn't crash the client. + // For details: https://github.com/ClickHouse/clickhouse-cpp/issues/266 + client_->Execute("CREATE TEMPORARY TABLE IF NOT EXISTS tableplus_crash_example (col AggregateFunction(argMax, Int32, DateTime(3))) engine = Memory"); + client_->Execute("insert into tableplus_crash_example values (unhex('010000000001089170A883010000'))"); + + client_->Select("select version()", + [&](const Block& block) { + std::cerr << PrettyPrintBlock{block} << std::endl; + }); + + // Column type `AggregateFunction` is not supported. 
+ EXPECT_THROW(client_->Select("select toTypeName(col), col from tableplus_crash_example", + [&](const Block& block) { + std::cerr << PrettyPrintBlock{block} << std::endl; + }), clickhouse::UnimplementedError); +} + + const auto LocalHostEndpoint = ClientOptions() .SetHost( getEnvOrDefault("CLICKHOUSE_HOST", "localhost")) .SetPort( getEnvOrDefault("CLICKHOUSE_PORT", "9000")) diff --git a/ut/type_parser_ut.cpp b/ut/type_parser_ut.cpp index ee1258b7..b0193ded 100644 --- a/ut/type_parser_ut.cpp +++ b/ut/type_parser_ut.cpp @@ -239,3 +239,34 @@ TEST(TypeParserCase, ParseMap) { ASSERT_EQ(ast.elements[1].meta, TypeAst::Terminal); ASSERT_EQ(ast.elements[1].name, "String"); } + +TEST(TypeParser, EmptyName) { + { + TypeAst ast; + EXPECT_EQ(false, TypeParser("").Parse(&ast)); + } + + { + TypeAst ast; + EXPECT_EQ(false, TypeParser(" ").Parse(&ast)); + } +} + +TEST(ParseTypeName, EmptyName) { + // Empty and invalid names shouldn't produce any AST and shouldn't crash + EXPECT_EQ(nullptr, ParseTypeName("")); + EXPECT_EQ(nullptr, ParseTypeName(" ")); + EXPECT_EQ(nullptr, ParseTypeName(std::string(5, '\0'))); +} + +TEST(TypeParser, AggregateFunction) { + { + TypeAst ast; + EXPECT_FALSE(TypeParser("AggregateFunction(argMax, Int32, DateTime(3))").Parse(&ast)); + } + + { + TypeAst ast; + EXPECT_FALSE(TypeParser("AggregateFunction(argMax, LowCardinality(Nullable(FixedString(4))), DateTime(3, 'UTC'))").Parse(&ast)); + } +} diff --git a/ut/types_ut.cpp b/ut/types_ut.cpp index 3f795269..7af343b5 100644 --- a/ut/types_ut.cpp +++ b/ut/types_ut.cpp @@ -80,7 +80,7 @@ TEST(TypesCase, IsEqual) { const std::string type_names[] = { "UInt8", "Int8", - "UInt128", +// "UInt128", "String", "FixedString(0)", "FixedString(10000)", @@ -128,7 +128,11 @@ TEST(TypesCase, IsEqual) { EXPECT_TRUE(type->IsEqual(*type)); for (const auto & other_type_name : type_names) { - const auto other_type = clickhouse::CreateColumnByType(other_type_name)->Type(); + SCOPED_TRACE(other_type_name); + const auto other_column = 
clickhouse::CreateColumnByType(other_type_name); + ASSERT_NE(nullptr, other_column); + + const auto other_type = other_column->Type(); const auto should_be_equal = type_name == other_type_name; EXPECT_EQ(should_be_equal, type->IsEqual(other_type))