43
43
#include " vec/core/types.h"
44
44
#include " vec/data_types/data_type.h"
45
45
#include " vec/data_types/data_type_nullable.h"
46
+ #include " vec/data_types/data_type_number.h"
46
47
#include " vec/data_types/data_type_string.h"
47
48
#include " vec/functions/function.h"
48
49
#include " vec/functions/simple_function_factory.h"
@@ -501,6 +502,84 @@ struct RegexpExtractAllImpl {
501
502
}
502
503
};
503
504
505
+ struct RegexpCountImpl {
506
+ static constexpr auto name = " regexp_count" ;
507
+
508
+ static void execute_impl (FunctionContext* context, ColumnPtr argument_columns[],
509
+ size_t input_rows_count, ColumnInt64::Container& result_data,
510
+ NullMap& null_map) {
511
+ const auto * str_col = check_and_get_column<ColumnString>(argument_columns[0 ].get ());
512
+ const auto * pattern_col = check_and_get_column<ColumnString>(argument_columns[1 ].get ());
513
+ for (int i = 0 ; i < input_rows_count; ++i) {
514
+ if (null_map[i]) {
515
+ result_data[i] = 0 ;
516
+ continue ;
517
+ }
518
+ result_data[i] = _execute_inner_loop<false >(context, str_col, pattern_col, null_map, i);
519
+ }
520
+ }
521
+
522
+ static void execute_impl_const_args (FunctionContext* context, ColumnPtr argument_columns[],
523
+ size_t input_rows_count,
524
+ ColumnInt64::Container& result_data, NullMap& null_map) {
525
+ const auto * str_col = check_and_get_column<ColumnString>(argument_columns[0 ].get ());
526
+ const auto * pattern_col = check_and_get_column<ColumnString>(argument_columns[1 ].get ());
527
+ for (int i = 0 ; i < input_rows_count; ++i) {
528
+ if (null_map[i]) {
529
+ result_data[i] = 0 ;
530
+ continue ;
531
+ }
532
+ result_data[i] = _execute_inner_loop<true >(context, str_col, pattern_col, null_map, i);
533
+ }
534
+ }
535
+
536
+ template <bool Const>
537
+ static int64_t _execute_inner_loop (FunctionContext* context, const ColumnString* str_col,
538
+ const ColumnString* pattern_col, NullMap& null_map,
539
+ const size_t index_now) {
540
+ re2::RE2* re = reinterpret_cast <re2::RE2*>(
541
+ context->get_function_state (FunctionContext::THREAD_LOCAL));
542
+ std::unique_ptr<re2::RE2> scoped_re;
543
+ if (re == nullptr ) {
544
+ std::string error_str;
545
+ const auto & pattern = pattern_col->get_data_at (index_check_const (index_now, Const));
546
+ bool st = StringFunctions::compile_regex (pattern, &error_str, StringRef (), StringRef (),
547
+ scoped_re);
548
+ if (!st) {
549
+ context->add_warning (error_str.c_str ());
550
+ null_map[index_now] = 1 ;
551
+ return 0 ;
552
+ }
553
+ re = scoped_re.get ();
554
+ }
555
+
556
+ const auto & str = str_col->get_data_at (index_now);
557
+
558
+ int64_t count = 0 ;
559
+ size_t pos = 0 ;
560
+ while (pos < str.size ) {
561
+ auto str_pos = str.data + pos;
562
+ auto str_size = str.size - pos;
563
+ re2::StringPiece str_sp_current = re2::StringPiece (str_pos, str_size);
564
+ re2::StringPiece match;
565
+
566
+ bool success = re->Match (str_sp_current, 0 , str_size, re2::RE2::UNANCHORED, &match, 1 );
567
+ if (!success) {
568
+ break ;
569
+ }
570
+ if (match.empty ()) {
571
+ pos += 1 ;
572
+ continue ;
573
+ }
574
+ count++;
575
+ size_t match_start = match.data () - str_sp_current.data ();
576
+ pos += match_start + match.size ();
577
+ }
578
+
579
+ return count;
580
+ }
581
+ };
582
+
504
583
// template FunctionRegexpFunctionality is used for regexp_xxxx series functions, not for regexp match.
505
584
template <typename Impl>
506
585
class FunctionRegexpFunctionality : public IFunction {
@@ -600,6 +679,77 @@ class FunctionRegexpFunctionality : public IFunction {
600
679
}
601
680
};
602
681
682
+ class FunctionRegexpCount : public IFunction {
683
+ public:
684
+ static constexpr auto name = " regexp_count" ;
685
+
686
+ static FunctionPtr create () { return std::make_shared<FunctionRegexpCount>(); }
687
+
688
+ String get_name () const override { return name; }
689
+
690
+ size_t get_number_of_arguments () const override { return 2 ; }
691
+
692
+ DataTypePtr get_return_type_impl (const DataTypes& arguments) const override {
693
+ DataTypePtr int64_type = std::make_shared<DataTypeInt64>();
694
+ return make_nullable (int64_type);
695
+ }
696
+
697
+ Status open (FunctionContext* context, FunctionContext::FunctionStateScope scope) override {
698
+ if (scope == FunctionContext::THREAD_LOCAL) {
699
+ if (context->is_col_constant (1 )) {
700
+ DCHECK (!context->get_function_state (scope));
701
+ const auto pattern_col = context->get_constant_col (1 )->column_ptr ;
702
+ const auto & pattern = pattern_col->get_data_at (0 );
703
+ if (pattern.size == 0 ) {
704
+ return Status::OK ();
705
+ }
706
+
707
+ std::string error_str;
708
+ std::unique_ptr<re2::RE2> scoped_re;
709
+ bool st = StringFunctions::compile_regex (pattern, &error_str, StringRef (),
710
+ StringRef (), scoped_re);
711
+ if (!st) {
712
+ context->set_error (error_str.c_str ());
713
+ return Status::InvalidArgument (error_str);
714
+ }
715
+ std::shared_ptr<re2::RE2> re (scoped_re.release ());
716
+ context->set_function_state (scope, re);
717
+ }
718
+ }
719
+ return Status::OK ();
720
+ }
721
+
722
+ Status execute_impl (FunctionContext* context, Block& block, const ColumnNumbers& arguments,
723
+ uint32_t result, size_t input_rows_count) const override {
724
+ auto result_null_map = ColumnUInt8::create (input_rows_count, 0 );
725
+ auto result_data_column = ColumnInt64::create (input_rows_count);
726
+ auto & result_data = result_data_column->get_data ();
727
+
728
+ bool col_const[2 ];
729
+ ColumnPtr argument_columns[2 ];
730
+ for (int i = 0 ; i < 2 ; ++i) {
731
+ col_const[i] = is_column_const (*block.get_by_position (arguments[i]).column );
732
+ }
733
+ argument_columns[0 ] = col_const[0 ] ? static_cast <const ColumnConst&>(
734
+ *block.get_by_position (arguments[0 ]).column )
735
+ .convert_to_full_column ()
736
+ : block.get_by_position (arguments[0 ]).column ;
737
+ default_preprocess_parameter_columns (argument_columns, col_const, {1 }, block, arguments);
738
+
739
+ if (col_const[1 ]) {
740
+ RegexpCountImpl::execute_impl_const_args (context, argument_columns, input_rows_count,
741
+ result_data, result_null_map->get_data ());
742
+ } else {
743
+ RegexpCountImpl::execute_impl (context, argument_columns, input_rows_count, result_data,
744
+ result_null_map->get_data ());
745
+ }
746
+
747
+ block.get_by_position (result).column =
748
+ ColumnNullable::create (std::move (result_data_column), std::move (result_null_map));
749
+ return Status::OK ();
750
+ }
751
+ };
752
+
603
753
void register_function_regexp_extract (SimpleFunctionFactory& factory) {
604
754
factory.register_function <FunctionRegexpReplace<RegexpReplaceImpl, ThreeParamTypes>>();
605
755
factory.register_function <FunctionRegexpReplace<RegexpReplaceImpl, FourParamTypes>>();
@@ -608,6 +758,7 @@ void register_function_regexp_extract(SimpleFunctionFactory& factory) {
608
758
factory.register_function <FunctionRegexpFunctionality<RegexpExtractImpl<true >>>();
609
759
factory.register_function <FunctionRegexpFunctionality<RegexpExtractImpl<false >>>();
610
760
factory.register_function <FunctionRegexpFunctionality<RegexpExtractAllImpl>>();
761
+ factory.register_function <FunctionRegexpCount>();
611
762
}
612
763
613
764
} // namespace doris::vectorized
0 commit comments