Skip to content

Fix crash on invalid AST #273

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Dec 27, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 33 additions & 16 deletions clickhouse/columns/factory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,26 @@
#include "../exceptions.h"

#include <stdexcept>
#include <string>

namespace clickhouse {
namespace {

// Like Python's list's []:
// * 0 - first element
// * 1 - second element
// * -1 - last element
// * -2 - one before last, etc.
const auto& GetASTChildElement(const TypeAst & ast, int position) {
if (static_cast<size_t>(abs(position)) >= ast.elements.size())
throw ValidationError("AST child element index out of bounds: " + std::to_string(position));

if (position < 0)
position = ast.elements.size() + position;

return ast.elements[static_cast<size_t>(position)];
}

static ColumnRef CreateTerminalColumn(const TypeAst& ast) {
switch (ast.code) {
case Type::Void:
Expand Down Expand Up @@ -58,24 +74,24 @@ static ColumnRef CreateTerminalColumn(const TypeAst& ast) {
return std::make_shared<ColumnFloat64>();

case Type::Decimal:
return std::make_shared<ColumnDecimal>(ast.elements.front().value, ast.elements.back().value);
return std::make_shared<ColumnDecimal>(GetASTChildElement(ast, 0).value, GetASTChildElement(ast, -1).value);
case Type::Decimal32:
return std::make_shared<ColumnDecimal>(9, ast.elements.front().value);
return std::make_shared<ColumnDecimal>(9, GetASTChildElement(ast, 0).value);
case Type::Decimal64:
return std::make_shared<ColumnDecimal>(18, ast.elements.front().value);
return std::make_shared<ColumnDecimal>(18, GetASTChildElement(ast, 0).value);
case Type::Decimal128:
return std::make_shared<ColumnDecimal>(38, ast.elements.front().value);
return std::make_shared<ColumnDecimal>(38, GetASTChildElement(ast, 0).value);

case Type::String:
return std::make_shared<ColumnString>();
case Type::FixedString:
return std::make_shared<ColumnFixedString>(ast.elements.front().value);
return std::make_shared<ColumnFixedString>(GetASTChildElement(ast, 0).value);

case Type::DateTime:
if (ast.elements.empty()) {
return std::make_shared<ColumnDateTime>();
} else {
return std::make_shared<ColumnDateTime>(ast.elements[0].value_string);
return std::make_shared<ColumnDateTime>(GetASTChildElement(ast, 0).value_string);
}
case Type::DateTime64:
if (ast.elements.empty()) {
Expand Down Expand Up @@ -120,13 +136,13 @@ static ColumnRef CreateColumnFromAst(const TypeAst& ast, CreateColumnByTypeSetti
switch (ast.meta) {
case TypeAst::Array: {
return std::make_shared<ColumnArray>(
CreateColumnFromAst(ast.elements.front(), settings)
CreateColumnFromAst(GetASTChildElement(ast, 0), settings)
);
}

case TypeAst::Nullable: {
return std::make_shared<ColumnNullable>(
CreateColumnFromAst(ast.elements.front(), settings),
CreateColumnFromAst(GetASTChildElement(ast, 0), settings),
std::make_shared<ColumnUInt8>()
);
}
Expand Down Expand Up @@ -159,9 +175,10 @@ static ColumnRef CreateColumnFromAst(const TypeAst& ast, CreateColumnByTypeSetti

enum_items.reserve(ast.elements.size() / 2);
for (size_t i = 0; i < ast.elements.size(); i += 2) {
enum_items.push_back(
Type::EnumItem{ ast.elements[i].value_string,
(int16_t)ast.elements[i + 1].value });
enum_items.push_back(Type::EnumItem{
ast.elements[i].value_string,
static_cast<int16_t>(ast.elements[i + 1].value)
});
}

if (ast.code == Type::Enum8) {
Expand All @@ -176,14 +193,14 @@ static ColumnRef CreateColumnFromAst(const TypeAst& ast, CreateColumnByTypeSetti
break;
}
case TypeAst::LowCardinality: {
const auto nested = ast.elements.front();
const auto nested = GetASTChildElement(ast, 0);
if (settings.low_cardinality_as_wrapped_column) {
switch (nested.code) {
// TODO (nemkov): update this to maximize code reuse.
case Type::String:
return std::make_shared<LowCardinalitySerializationAdaptor<ColumnString>>();
case Type::FixedString:
return std::make_shared<LowCardinalitySerializationAdaptor<ColumnFixedString>>(nested.elements.front().value);
return std::make_shared<LowCardinalitySerializationAdaptor<ColumnFixedString>>(GetASTChildElement(nested, 0).value);
case Type::Nullable:
throw UnimplementedError("LowCardinality(" + nested.name + ") is not supported with LowCardinalityAsWrappedColumn on");
default:
Expand All @@ -196,11 +213,11 @@ static ColumnRef CreateColumnFromAst(const TypeAst& ast, CreateColumnByTypeSetti
case Type::String:
return std::make_shared<ColumnLowCardinalityT<ColumnString>>();
case Type::FixedString:
return std::make_shared<ColumnLowCardinalityT<ColumnFixedString>>(nested.elements.front().value);
return std::make_shared<ColumnLowCardinalityT<ColumnFixedString>>(GetASTChildElement(nested, 0).value);
case Type::Nullable:
return std::make_shared<ColumnLowCardinality>(
std::make_shared<ColumnNullable>(
CreateColumnFromAst(nested.elements.front(), settings),
CreateColumnFromAst(GetASTChildElement(nested, 0), settings),
std::make_shared<ColumnUInt8>()
)
);
Expand All @@ -210,7 +227,7 @@ static ColumnRef CreateColumnFromAst(const TypeAst& ast, CreateColumnByTypeSetti
}
}
case TypeAst::SimpleAggregateFunction: {
return CreateTerminalColumn(ast.elements.back());
return CreateTerminalColumn(GetASTChildElement(ast, -1));
}

case TypeAst::Map: {
Expand Down
57 changes: 51 additions & 6 deletions clickhouse/types/type_parser.cpp
Original file line number Diff line number Diff line change
@@ -1,10 +1,21 @@
#include "type_parser.h"

#include "clickhouse/exceptions.h"
#include "clickhouse/base/platform.h" // for _win_

#include <algorithm>
#include <cmath>
#include <map>
#include <mutex>
#include <unordered_map>

#if defined _win_
#include <string.h>
#else
#include <strings.h>
#endif


namespace clickhouse {

bool TypeAst::operator==(const TypeAst & other) const {
Expand All @@ -16,6 +27,7 @@ bool TypeAst::operator==(const TypeAst & other) const {
}

static const std::unordered_map<std::string, Type::Code> kTypeCode = {
{ "Void", Type::Void },
{ "Int8", Type::Int8 },
{ "Int16", Type::Int16 },
{ "Int32", Type::Int32 },
Expand All @@ -41,23 +53,38 @@ static const std::unordered_map<std::string, Type::Code> kTypeCode = {
{ "IPv4", Type::IPv4 },
{ "IPv6", Type::IPv6 },
{ "Int128", Type::Int128 },
// { "UInt128", Type::UInt128 },
{ "Decimal", Type::Decimal },
{ "Decimal32", Type::Decimal32 },
{ "Decimal64", Type::Decimal64 },
{ "Decimal128", Type::Decimal128 },
{ "LowCardinality", Type::LowCardinality },
{ "Map", Type::Map},
{ "Point", Type::Point},
{ "Ring", Type::Ring},
{ "Polygon", Type::Polygon},
{ "MultiPolygon", Type::MultiPolygon},
{ "Map", Type::Map },
{ "Point", Type::Point },
{ "Ring", Type::Ring },
{ "Polygon", Type::Polygon },
{ "MultiPolygon", Type::MultiPolygon },
};

template <typename L, typename R>
inline int CompateStringsCaseInsensitive(const L& left, const R& right) {
int64_t size_diff = left.size() - right.size();
if (size_diff != 0)
return size_diff > 0 ? 1 : -1;

#if defined _win_
return _strnicmp(left.data(), right.data(), left.size());
#else
return strncasecmp(left.data(), right.data(), left.size());
#endif
}

static Type::Code GetTypeCode(const std::string& name) {
auto it = kTypeCode.find(name);
if (it != kTypeCode.end()) {
return it->second;
}

return Type::Void;
}

Expand Down Expand Up @@ -97,6 +124,17 @@ static TypeAst::Meta GetTypeMeta(const StringView& name) {
return TypeAst::Terminal;
}

bool ValidateAST(const TypeAst& ast) {
// Void terminal that is not actually "void" produced when unknown type is encountered.
if (ast.meta == TypeAst::Terminal
&& ast.code == Type::Void
&& CompateStringsCaseInsensitive(ast.name, std::string_view("void")) != 0)
//throw UnimplementedError("Unsupported type: " + ast.name);
return false;

return true;
}


TypeParser::TypeParser(const StringView& name)
: cur_(name.data())
Expand All @@ -111,6 +149,7 @@ bool TypeParser::Parse(TypeAst* type) {
type_ = type;
open_elements_.push(type_);

size_t processed_tokens = 0;
do {
const Token & token = NextToken();
switch (token.type) {
Expand Down Expand Up @@ -159,11 +198,17 @@ bool TypeParser::Parse(TypeAst* type) {
// Ubalanced braces, brackets, etc is an error.
if (open_elements_.size() != 1)
return false;
return true;

// Empty input string, no tokens produced
if (processed_tokens == 0)
return false;

return ValidateAST(*type);
}
case Token::Invalid:
return false;
}
++processed_tokens;
} while (true);
}

Expand Down
6 changes: 6 additions & 0 deletions ut/CreateColumnByType_ut.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,12 @@ TEST(CreateColumnByType, DateTime) {
ASSERT_EQ(CreateColumnByType("DateTime64(3, 'UTC')")->As<ColumnDateTime64>()->Timezone(), "UTC");
}

TEST(CreateColumnByType, AggregateFunction) {
EXPECT_EQ(nullptr, CreateColumnByType("AggregateFunction(argMax, Int32, DateTime64(3))"));
EXPECT_EQ(nullptr, CreateColumnByType("AggregateFunction(argMax, FIxedString(10), DateTime64(3, 'UTC'))"));
}


class CreateColumnByTypeWithName : public ::testing::TestWithParam<const char* /*Column Type String*/>
{};

Expand Down
19 changes: 19 additions & 0 deletions ut/client_ut.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1313,6 +1313,25 @@ TEST_P(ClientCase, OnProfileEvents) {
}
}

TEST_P(ClientCase, SelectAggregateFunction) {
// Verifies that perofing SELECT value of type AggregateFunction(...) doesn't crash the client.
// For details: https://github.com/ClickHouse/clickhouse-cpp/issues/266
client_->Execute("CREATE TEMPORARY TABLE IF NOT EXISTS tableplus_crash_example (col AggregateFunction(argMax, Int32, DateTime(3))) engine = Memory");
client_->Execute("insert into tableplus_crash_example values (unhex('010000000001089170A883010000'))");

client_->Select("select version()",
[&](const Block& block) {
std::cerr << PrettyPrintBlock{block} << std::endl;
});

// Column type `AggregateFunction` is not supported.
EXPECT_THROW(client_->Select("select toTypeName(col), col from tableplus_crash_example",
[&](const Block& block) {
std::cerr << PrettyPrintBlock{block} << std::endl;
}), clickhouse::UnimplementedError);
}


const auto LocalHostEndpoint = ClientOptions()
.SetHost( getEnvOrDefault("CLICKHOUSE_HOST", "localhost"))
.SetPort( getEnvOrDefault<size_t>("CLICKHOUSE_PORT", "9000"))
Expand Down
31 changes: 31 additions & 0 deletions ut/type_parser_ut.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -239,3 +239,34 @@ TEST(TypeParserCase, ParseMap) {
ASSERT_EQ(ast.elements[1].meta, TypeAst::Terminal);
ASSERT_EQ(ast.elements[1].name, "String");
}

TEST(TypeParser, EmptyName) {
{
TypeAst ast;
EXPECT_EQ(false, TypeParser("").Parse(&ast));
}

{
TypeAst ast;
EXPECT_EQ(false, TypeParser(" ").Parse(&ast));
}
}

TEST(ParseTypeName, EmptyName) {
// Empty and invalid names shouldn't produce any AST and shoudn't crash
EXPECT_EQ(nullptr, ParseTypeName(""));
EXPECT_EQ(nullptr, ParseTypeName(" "));
EXPECT_EQ(nullptr, ParseTypeName(std::string(5, '\0')));
}

TEST(TypeParser, AggregateFunction) {
{
TypeAst ast;
EXPECT_FALSE(TypeParser("AggregateFunction(argMax, Int32, DateTime(3))").Parse(&ast));
}

{
TypeAst ast;
EXPECT_FALSE(TypeParser("AggregateFunction(argMax, LowCardinality(Nullable(FixedString(4))), DateTime(3, 'UTC'))").Parse(&ast));
}
}
8 changes: 6 additions & 2 deletions ut/types_ut.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ TEST(TypesCase, IsEqual) {
const std::string type_names[] = {
"UInt8",
"Int8",
"UInt128",
// "UInt128",
"String",
"FixedString(0)",
"FixedString(10000)",
Expand Down Expand Up @@ -128,7 +128,11 @@ TEST(TypesCase, IsEqual) {
EXPECT_TRUE(type->IsEqual(*type));

for (const auto & other_type_name : type_names) {
const auto other_type = clickhouse::CreateColumnByType(other_type_name)->Type();
SCOPED_TRACE(other_type_name);
const auto other_column = clickhouse::CreateColumnByType(other_type_name);
ASSERT_NE(nullptr, other_column);

const auto other_type = other_column->Type();

const auto should_be_equal = type_name == other_type_name;
EXPECT_EQ(should_be_equal, type->IsEqual(other_type))
Expand Down