Skip to content

Commit f425fc8

Browse files
committed
[feat](func) Support regexp_count function
1 parent 2c5effb commit f425fc8

File tree

6 files changed

+520
-0
lines changed

6 files changed

+520
-0
lines changed

be/src/vec/functions/function_regexp.cpp

Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
#include "vec/core/types.h"
4444
#include "vec/data_types/data_type.h"
4545
#include "vec/data_types/data_type_nullable.h"
46+
#include "vec/data_types/data_type_number.h"
4647
#include "vec/data_types/data_type_string.h"
4748
#include "vec/functions/function.h"
4849
#include "vec/functions/simple_function_factory.h"
@@ -501,6 +502,84 @@ struct RegexpExtractAllImpl {
501502
}
502503
};
503504

505+
struct RegexpCountImpl {
506+
static constexpr auto name = "regexp_count";
507+
508+
static void execute_impl(FunctionContext* context, ColumnPtr argument_columns[],
509+
size_t input_rows_count, ColumnInt64::Container& result_data,
510+
NullMap& null_map) {
511+
const auto* str_col = check_and_get_column<ColumnString>(argument_columns[0].get());
512+
const auto* pattern_col = check_and_get_column<ColumnString>(argument_columns[1].get());
513+
for (int i = 0; i < input_rows_count; ++i) {
514+
if (null_map[i]) {
515+
result_data[i] = 0;
516+
continue;
517+
}
518+
result_data[i] = _execute_inner_loop<false>(context, str_col, pattern_col, null_map, i);
519+
}
520+
}
521+
522+
static void execute_impl_const_args(FunctionContext* context, ColumnPtr argument_columns[],
523+
size_t input_rows_count,
524+
ColumnInt64::Container& result_data, NullMap& null_map) {
525+
const auto* str_col = check_and_get_column<ColumnString>(argument_columns[0].get());
526+
const auto* pattern_col = check_and_get_column<ColumnString>(argument_columns[1].get());
527+
for (int i = 0; i < input_rows_count; ++i) {
528+
if (null_map[i]) {
529+
result_data[i] = 0;
530+
continue;
531+
}
532+
result_data[i] = _execute_inner_loop<true>(context, str_col, pattern_col, null_map, i);
533+
}
534+
}
535+
536+
template <bool Const>
537+
static int64_t _execute_inner_loop(FunctionContext* context, const ColumnString* str_col,
538+
const ColumnString* pattern_col, NullMap& null_map,
539+
const size_t index_now) {
540+
re2::RE2* re = reinterpret_cast<re2::RE2*>(
541+
context->get_function_state(FunctionContext::THREAD_LOCAL));
542+
std::unique_ptr<re2::RE2> scoped_re;
543+
if (re == nullptr) {
544+
std::string error_str;
545+
const auto& pattern = pattern_col->get_data_at(index_check_const(index_now, Const));
546+
bool st = StringFunctions::compile_regex(pattern, &error_str, StringRef(), StringRef(),
547+
scoped_re);
548+
if (!st) {
549+
context->add_warning(error_str.c_str());
550+
null_map[index_now] = 1;
551+
return 0;
552+
}
553+
re = scoped_re.get();
554+
}
555+
556+
const auto& str = str_col->get_data_at(index_now);
557+
558+
int64_t count = 0;
559+
size_t pos = 0;
560+
while (pos < str.size) {
561+
auto str_pos = str.data + pos;
562+
auto str_size = str.size - pos;
563+
re2::StringPiece str_sp_current = re2::StringPiece(str_pos, str_size);
564+
re2::StringPiece match;
565+
566+
bool success = re->Match(str_sp_current, 0, str_size, re2::RE2::UNANCHORED, &match, 1);
567+
if (!success) {
568+
break;
569+
}
570+
if (match.empty()) {
571+
pos += 1;
572+
continue;
573+
}
574+
count++;
575+
size_t match_start = match.data() - str_sp_current.data();
576+
pos += match_start + match.size();
577+
}
578+
579+
return count;
580+
}
581+
};
582+
504583
// template FunctionRegexpFunctionality is used for regexp_xxxx series functions, not for regexp match.
505584
template <typename Impl>
506585
class FunctionRegexpFunctionality : public IFunction {
@@ -600,6 +679,77 @@ class FunctionRegexpFunctionality : public IFunction {
600679
}
601680
};
602681

682+
class FunctionRegexpCount : public IFunction {
683+
public:
684+
static constexpr auto name = "regexp_count";
685+
686+
static FunctionPtr create() { return std::make_shared<FunctionRegexpCount>(); }
687+
688+
String get_name() const override { return name; }
689+
690+
size_t get_number_of_arguments() const override { return 2; }
691+
692+
DataTypePtr get_return_type_impl(const DataTypes& arguments) const override {
693+
DataTypePtr int64_type = std::make_shared<DataTypeInt64>();
694+
return make_nullable(int64_type);
695+
}
696+
697+
Status open(FunctionContext* context, FunctionContext::FunctionStateScope scope) override {
698+
if (scope == FunctionContext::THREAD_LOCAL) {
699+
if (context->is_col_constant(1)) {
700+
DCHECK(!context->get_function_state(scope));
701+
const auto pattern_col = context->get_constant_col(1)->column_ptr;
702+
const auto& pattern = pattern_col->get_data_at(0);
703+
if (pattern.size == 0) {
704+
return Status::OK();
705+
}
706+
707+
std::string error_str;
708+
std::unique_ptr<re2::RE2> scoped_re;
709+
bool st = StringFunctions::compile_regex(pattern, &error_str, StringRef(),
710+
StringRef(), scoped_re);
711+
if (!st) {
712+
context->set_error(error_str.c_str());
713+
return Status::InvalidArgument(error_str);
714+
}
715+
std::shared_ptr<re2::RE2> re(scoped_re.release());
716+
context->set_function_state(scope, re);
717+
}
718+
}
719+
return Status::OK();
720+
}
721+
722+
Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments,
723+
uint32_t result, size_t input_rows_count) const override {
724+
auto result_null_map = ColumnUInt8::create(input_rows_count, 0);
725+
auto result_data_column = ColumnInt64::create(input_rows_count);
726+
auto& result_data = result_data_column->get_data();
727+
728+
bool col_const[2];
729+
ColumnPtr argument_columns[2];
730+
for (int i = 0; i < 2; ++i) {
731+
col_const[i] = is_column_const(*block.get_by_position(arguments[i]).column);
732+
}
733+
argument_columns[0] = col_const[0] ? static_cast<const ColumnConst&>(
734+
*block.get_by_position(arguments[0]).column)
735+
.convert_to_full_column()
736+
: block.get_by_position(arguments[0]).column;
737+
default_preprocess_parameter_columns(argument_columns, col_const, {1}, block, arguments);
738+
739+
if (col_const[1]) {
740+
RegexpCountImpl::execute_impl_const_args(context, argument_columns, input_rows_count,
741+
result_data, result_null_map->get_data());
742+
} else {
743+
RegexpCountImpl::execute_impl(context, argument_columns, input_rows_count, result_data,
744+
result_null_map->get_data());
745+
}
746+
747+
block.get_by_position(result).column =
748+
ColumnNullable::create(std::move(result_data_column), std::move(result_null_map));
749+
return Status::OK();
750+
}
751+
};
752+
603753
void register_function_regexp_extract(SimpleFunctionFactory& factory) {
604754
factory.register_function<FunctionRegexpReplace<RegexpReplaceImpl, ThreeParamTypes>>();
605755
factory.register_function<FunctionRegexpReplace<RegexpReplaceImpl, FourParamTypes>>();
@@ -608,6 +758,7 @@ void register_function_regexp_extract(SimpleFunctionFactory& factory) {
608758
factory.register_function<FunctionRegexpFunctionality<RegexpExtractImpl<true>>>();
609759
factory.register_function<FunctionRegexpFunctionality<RegexpExtractImpl<false>>>();
610760
factory.register_function<FunctionRegexpFunctionality<RegexpExtractAllImpl>>();
761+
factory.register_function<FunctionRegexpCount>();
611762
}
612763

613764
} // namespace doris::vectorized

fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinScalarFunctions.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -361,6 +361,7 @@
361361
import org.apache.doris.nereids.trees.expressions.functions.scalar.Radians;
362362
import org.apache.doris.nereids.trees.expressions.functions.scalar.Random;
363363
import org.apache.doris.nereids.trees.expressions.functions.scalar.RandomBytes;
364+
import org.apache.doris.nereids.trees.expressions.functions.scalar.RegexpCount;
364365
import org.apache.doris.nereids.trees.expressions.functions.scalar.RegexpExtract;
365366
import org.apache.doris.nereids.trees.expressions.functions.scalar.RegexpExtractAll;
366367
import org.apache.doris.nereids.trees.expressions.functions.scalar.RegexpExtractOrNull;
@@ -863,6 +864,7 @@ public class BuiltinScalarFunctions implements FunctionHelper {
863864
scalar(Radians.class, "radians"),
864865
scalar(Random.class, "rand", "random"),
865866
scalar(Regexp.class, "regexp"),
867+
scalar(RegexpCount.class, "regexp_count"),
866868
scalar(RegexpExtract.class, "regexp_extract"),
867869
scalar(RegexpExtractAll.class, "regexp_extract_all"),
868870
scalar(RegexpExtractOrNull.class, "regexp_extract_or_null"),
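fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/RegexpCount.java

Lines changed: 96 additions & 0 deletions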
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
package org.apache.doris.nereids.trees.expressions.functions.scalar;
19+
20+
import org.apache.doris.catalog.FunctionSignature;
21+
import org.apache.doris.nereids.trees.expressions.Expression;
22+
import org.apache.doris.nereids.trees.expressions.functions.AlwaysNullable;
23+
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
24+
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullLiteral;
25+
import org.apache.doris.nereids.trees.expressions.shape.BinaryExpression;
26+
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
27+
import org.apache.doris.nereids.types.BigIntType;
28+
import org.apache.doris.nereids.types.StringType;
29+
import org.apache.doris.nereids.types.VarcharType;
30+
31+
import com.google.common.base.Preconditions;
32+
import com.google.common.collect.ImmutableList;
33+
34+
import java.util.List;
35+
36+
/**
37+
* ScalarFunction 'regexp_count'. This class is generated by GenerateFunction.
38+
*/
39+
public class RegexpCount extends ScalarFunction
40+
implements BinaryExpression, ExplicitlyCastableSignature, AlwaysNullable, PropagateNullLiteral {
41+
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
42+
FunctionSignature.ret(BigIntType.INSTANCE).args(VarcharType.SYSTEM_DEFAULT, VarcharType.SYSTEM_DEFAULT),
43+
FunctionSignature.ret(BigIntType.INSTANCE).args(StringType.INSTANCE, StringType.INSTANCE)
44+
);
45+
46+
/**
47+
* constructor with 2 arguments.
48+
*/
49+
public RegexpCount(Expression arg0, Expression arg1) {
50+
super("regexp_count", arg0, arg1);
51+
}
52+
53+
/**
54+
* withChildren.
55+
*/
56+
@Override
57+
public RegexpCount withChildren(List<Expression> children) {
58+
Preconditions.checkArgument(children.size() == 2);
59+
return new RegexpCount(children.get(0), children.get(1));
60+
}
61+
62+
@Override
63+
public List<FunctionSignature> getSignatures() {
64+
return SIGNATURES;
65+
}
66+
67+
@Override
68+
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
69+
return visitor.visitRegexpCount(this, context);
70+
}
71+
}
72+
73+
74+
75+
76+
77+
78+
79+
80+
81+
82+
83+
84+
85+
86+
87+
88+
89+
90+
91+
92+
93+
94+
95+
96+

fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ScalarFunctionVisitor.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -361,6 +361,7 @@
361361
import org.apache.doris.nereids.trees.expressions.functions.scalar.Radians;
362362
import org.apache.doris.nereids.trees.expressions.functions.scalar.Random;
363363
import org.apache.doris.nereids.trees.expressions.functions.scalar.RandomBytes;
364+
import org.apache.doris.nereids.trees.expressions.functions.scalar.RegexpCount;
364365
import org.apache.doris.nereids.trees.expressions.functions.scalar.RegexpExtract;
365366
import org.apache.doris.nereids.trees.expressions.functions.scalar.RegexpExtractAll;
366367
import org.apache.doris.nereids.trees.expressions.functions.scalar.RegexpExtractOrNull;
@@ -1859,6 +1860,10 @@ default R visitRegexpReplaceOne(RegexpReplaceOne regexpReplaceOne, C context) {
18591860
return visitScalarFunction(regexpReplaceOne, context);
18601861
}
18611862

1863+
default R visitRegexpCount(RegexpCount regexpCount, C context) {
1864+
return visitScalarFunction(regexpCount, context);
1865+
}
1866+
18621867
default R visitRepeat(Repeat repeat, C context) {
18631868
return visitScalarFunction(repeat, context);
18641869
}

0 commit comments

Comments
 (0)