Skip to content

Commit aef8a77

Browse files
authored
Add user-defined checker for server side (#502)
Signed-off-by: Xintao <[email protected]>
1 parent 8dd16b4 commit aef8a77

File tree

4 files changed

+112
-3
lines changed

4 files changed

+112
-3
lines changed

src/call/server.rs

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,10 @@ use crate::codec::{DeserializeFn, SerializeFn};
2525
use crate::cq::CompletionQueue;
2626
use crate::error::{Error, Result};
2727
use crate::metadata::Metadata;
28+
use crate::server::ServerChecker;
2829
use crate::server::{BoxHandler, RequestCallContext};
2930
use crate::task::{BatchFuture, CallTag, Executor, Kicker};
31+
use crate::CheckResult;
3032

3133
pub struct Deadline {
3234
spec: gpr_timespec,
@@ -74,12 +76,13 @@ impl RequestContext {
7476
cq: &CompletionQueue,
7577
rc: &mut RequestCallContext,
7678
) -> result::Result<(), Self> {
79+
let checker = rc.get_checker();
7780
let handler = unsafe { rc.get_handler(self.method()) };
7881
match handler {
7982
Some(handler) => match handler.method_type() {
8083
MethodType::Unary | MethodType::ServerStreaming => Err(self),
8184
_ => {
82-
execute(self, cq, None, handler);
85+
execute(self, cq, None, handler, checker);
8386
Ok(())
8487
}
8588
},
@@ -225,9 +228,10 @@ impl UnaryRequestContext {
225228
cq: &CompletionQueue,
226229
reader: Option<MessageReader>,
227230
) {
231+
let checker = rc.get_checker();
228232
let handler = unsafe { rc.get_handler(self.request.method()).unwrap() };
229233
if reader.is_some() {
230-
return execute(self.request, cq, reader, handler);
234+
return execute(self.request, cq, reader, handler, checker);
231235
}
232236

233237
let status = RpcStatus::new(RpcStatusCode::INTERNAL, Some("No payload".to_owned()));
@@ -775,7 +779,19 @@ fn execute(
775779
cq: &CompletionQueue,
776780
payload: Option<MessageReader>,
777781
f: &mut BoxHandler,
782+
mut checkers: Vec<Box<dyn ServerChecker>>,
778783
) {
779784
let rpc_ctx = RpcContext::new(ctx, cq);
785+
786+
for handler in checkers.iter_mut() {
787+
match handler.check(&rpc_ctx) {
788+
CheckResult::Continue => {}
789+
CheckResult::Abort(status) => {
790+
rpc_ctx.call().abort(&status);
791+
return;
792+
}
793+
}
794+
}
795+
780796
f.handle(rpc_ctx, payload)
781797
}

src/lib.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,4 +77,6 @@ pub use crate::security::{
7777
CertificateRequestType, ChannelCredentials, ChannelCredentialsBuilder, ServerCredentials,
7878
ServerCredentialsBuilder, ServerCredentialsFetcher,
7979
};
80-
pub use crate::server::{Server, ServerBuilder, Service, ServiceBuilder, ShutdownFuture};
80+
pub use crate::server::{
81+
CheckResult, Server, ServerBuilder, ServerChecker, Service, ServiceBuilder, ShutdownFuture,
82+
};

src/server.rs

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ use crate::env::Environment;
2121
use crate::error::{Error, Result};
2222
use crate::task::{CallTag, CqFuture};
2323
use crate::RpcContext;
24+
use crate::RpcStatus;
2425

2526
const DEFAULT_REQUEST_SLOTS_PER_CQ: usize = 1024;
2627

@@ -266,6 +267,24 @@ impl ServiceBuilder {
266267
}
267268
}
268269

270+
/// Used to indicate the result of the check. If it returns `Abort`,
271+
/// skip the subsequent checkers and abort the grpc call.
272+
pub enum CheckResult {
273+
Continue,
274+
Abort(RpcStatus),
275+
}
276+
277+
pub trait ServerChecker: Send {
278+
fn check(&mut self, ctx: &RpcContext) -> CheckResult;
279+
fn box_clone(&self) -> Box<dyn ServerChecker>;
280+
}
281+
282+
impl Clone for Box<dyn ServerChecker> {
283+
fn clone(&self) -> Self {
284+
self.box_clone()
285+
}
286+
}
287+
269288
/// A gRPC service.
270289
///
271290
/// Use [`ServiceBuilder`] to build a [`Service`].
@@ -280,6 +299,7 @@ pub struct ServerBuilder {
280299
args: Option<ChannelArgs>,
281300
slots_per_cq: usize,
282301
handlers: HashMap<&'static [u8], BoxHandler>,
302+
checkers: Vec<Box<dyn ServerChecker>>,
283303
}
284304

285305
impl ServerBuilder {
@@ -291,6 +311,7 @@ impl ServerBuilder {
291311
args: None,
292312
slots_per_cq: DEFAULT_REQUEST_SLOTS_PER_CQ,
293313
handlers: HashMap::new(),
314+
checkers: Vec::new(),
294315
}
295316
}
296317

@@ -320,6 +341,16 @@ impl ServerBuilder {
320341
self
321342
}
322343

344+
/// Add a custom checker to handle some tasks before the grpc call handler starts.
345+
/// This allows users to operate grpc call based on the context. Users can add
346+
/// multiple checkers and they will be executed in the order added.
347+
///
348+
/// TODO: Extend this interface to intercepte each payload like grpc-c++.
349+
pub fn add_checker<C: ServerChecker + 'static>(mut self, checker: C) -> ServerBuilder {
350+
self.checkers.push(Box::new(checker));
351+
self
352+
}
353+
323354
/// Finalize the [`ServerBuilder`] and build the [`Server`].
324355
pub fn build(mut self) -> Result<Server> {
325356
let args = self
@@ -355,6 +386,7 @@ impl ServerBuilder {
355386
slots_per_cq: self.slots_per_cq,
356387
}),
357388
handlers: self.handlers,
389+
checkers: self.checkers,
358390
})
359391
}
360392
}
@@ -439,6 +471,7 @@ pub type BoxHandler = Box<dyn CloneableHandler>;
439471
pub struct RequestCallContext {
440472
server: Arc<ServerCore>,
441473
registry: Arc<UnsafeCell<HashMap<&'static [u8], BoxHandler>>>,
474+
checkers: Vec<Box<dyn ServerChecker>>,
442475
}
443476

444477
impl RequestCallContext {
@@ -449,6 +482,10 @@ impl RequestCallContext {
449482
let registry = &mut *self.registry.get();
450483
registry.get_mut(path)
451484
}
485+
486+
pub(crate) fn get_checker(&self) -> Vec<Box<dyn ServerChecker>> {
487+
self.checkers.clone()
488+
}
452489
}
453490

454491
// Apparently, its life time is guaranteed by the ref count, hence is safe to be sent
@@ -506,6 +543,7 @@ pub struct Server {
506543
env: Arc<Environment>,
507544
core: Arc<ServerCore>,
508545
handlers: HashMap<&'static [u8], BoxHandler>,
546+
checkers: Vec<Box<dyn ServerChecker>>,
509547
}
510548

511549
impl Server {
@@ -549,6 +587,7 @@ impl Server {
549587
let rc = RequestCallContext {
550588
server: self.core.clone(),
551589
registry: Arc::new(UnsafeCell::new(registry)),
590+
checkers: self.checkers.clone(),
552591
};
553592
for _ in 0..self.core.slots_per_cq {
554593
request_call(rc.clone(), cq);

tests-and-examples/tests/cases/misc.rs

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,3 +221,55 @@ fn test_shutdown_when_exists_grpc_call() {
221221
"Send should get error because server is shutdown, so the grpc is cancelled."
222222
);
223223
}
224+
225+
#[test]
226+
fn test_custom_checker_server_side() {
227+
let flag = Arc::new(atomic::AtomicBool::new(false));
228+
let checker = FlagChecker { flag: flag.clone() };
229+
230+
let env = Arc::new(Environment::new(2));
231+
// Start a server and delay the process of grpc server.
232+
let service = create_greeter(PeerService);
233+
let mut server = ServerBuilder::new(env.clone())
234+
.add_checker(checker)
235+
.register_service(service)
236+
.bind("127.0.0.1", 0)
237+
.build()
238+
.unwrap();
239+
server.start();
240+
let port = server.bind_addrs().next().unwrap().1;
241+
let ch = ChannelBuilder::new(env).connect(&format!("127.0.0.1:{}", port));
242+
let client = GreeterClient::new(ch);
243+
let req = HelloRequest::default();
244+
245+
let _ = client.say_hello(&req).unwrap();
246+
let _ = client.say_hello(&req).unwrap();
247+
248+
flag.store(true, Ordering::SeqCst);
249+
assert_eq!(
250+
client.say_hello(&req).unwrap_err().to_string(),
251+
"RpcFailure: 15-DATA_LOSS ".to_owned()
252+
);
253+
}
254+
255+
#[derive(Clone)]
256+
struct FlagChecker {
257+
flag: Arc<atomic::AtomicBool>,
258+
}
259+
260+
impl ServerChecker for FlagChecker {
261+
fn check(&mut self, ctx: &RpcContext) -> CheckResult {
262+
let method = String::from_utf8(ctx.method().to_owned());
263+
assert_eq!(&method.unwrap(), "/helloworld.Greeter/SayHello");
264+
265+
if self.flag.load(Ordering::SeqCst) {
266+
CheckResult::Abort(RpcStatus::new(RpcStatusCode::DATA_LOSS, None))
267+
} else {
268+
CheckResult::Continue
269+
}
270+
}
271+
272+
fn box_clone(&self) -> Box<dyn ServerChecker> {
273+
Box::new(self.clone())
274+
}
275+
}

0 commit comments

Comments
 (0)