Skip to content

Add user-defined checker for server side #502

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Dec 17, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions src/call/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,10 @@ use crate::codec::{DeserializeFn, SerializeFn};
use crate::cq::CompletionQueue;
use crate::error::{Error, Result};
use crate::metadata::Metadata;
use crate::server::ServerChecker;
use crate::server::{BoxHandler, RequestCallContext};
use crate::task::{BatchFuture, CallTag, Executor, Kicker};
use crate::CheckResult;

pub struct Deadline {
spec: gpr_timespec,
Expand Down Expand Up @@ -74,12 +76,13 @@ impl RequestContext {
cq: &CompletionQueue,
rc: &mut RequestCallContext,
) -> result::Result<(), Self> {
let checker = rc.get_checker();
let handler = unsafe { rc.get_handler(self.method()) };
match handler {
Some(handler) => match handler.method_type() {
MethodType::Unary | MethodType::ServerStreaming => Err(self),
_ => {
execute(self, cq, None, handler);
execute(self, cq, None, handler, checker);
Ok(())
}
},
Expand Down Expand Up @@ -225,9 +228,10 @@ impl UnaryRequestContext {
cq: &CompletionQueue,
reader: Option<MessageReader>,
) {
let checker = rc.get_checker();
let handler = unsafe { rc.get_handler(self.request.method()).unwrap() };
if reader.is_some() {
return execute(self.request, cq, reader, handler);
return execute(self.request, cq, reader, handler, checker);
}

let status = RpcStatus::new(RpcStatusCode::INTERNAL, Some("No payload".to_owned()));
Expand Down Expand Up @@ -775,7 +779,19 @@ fn execute(
cq: &CompletionQueue,
payload: Option<MessageReader>,
f: &mut BoxHandler,
mut checkers: Vec<Box<dyn ServerChecker>>,
) {
let rpc_ctx = RpcContext::new(ctx, cq);

for handler in checkers.iter_mut() {
match handler.check(&rpc_ctx) {
CheckResult::Continue => {}
CheckResult::Abort(status) => {
rpc_ctx.call().abort(&status);
return;
}
}
}

f.handle(rpc_ctx, payload)
}
4 changes: 3 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,4 +77,6 @@ pub use crate::security::{
CertificateRequestType, ChannelCredentials, ChannelCredentialsBuilder, ServerCredentials,
ServerCredentialsBuilder, ServerCredentialsFetcher,
};
pub use crate::server::{Server, ServerBuilder, Service, ServiceBuilder, ShutdownFuture};
pub use crate::server::{
CheckResult, Server, ServerBuilder, ServerChecker, Service, ServiceBuilder, ShutdownFuture,
};
39 changes: 39 additions & 0 deletions src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ use crate::env::Environment;
use crate::error::{Error, Result};
use crate::task::{CallTag, CqFuture};
use crate::RpcContext;
use crate::RpcStatus;

const DEFAULT_REQUEST_SLOTS_PER_CQ: usize = 1024;

Expand Down Expand Up @@ -266,6 +267,24 @@ impl ServiceBuilder {
}
}

/// Used to indicate the result of the check. If it returns `Abort`,
/// skip the subsequent checkers and abort the grpc call.
pub enum CheckResult {
Continue,
Abort(RpcStatus),
}

pub trait ServerChecker: Send {
fn check(&mut self, ctx: &RpcContext) -> CheckResult;
fn box_clone(&self) -> Box<dyn ServerChecker>;
}

impl Clone for Box<dyn ServerChecker> {
fn clone(&self) -> Self {
self.box_clone()
}
}

/// A gRPC service.
///
/// Use [`ServiceBuilder`] to build a [`Service`].
Expand All @@ -280,6 +299,7 @@ pub struct ServerBuilder {
args: Option<ChannelArgs>,
slots_per_cq: usize,
handlers: HashMap<&'static [u8], BoxHandler>,
checkers: Vec<Box<dyn ServerChecker>>,
}

impl ServerBuilder {
Expand All @@ -291,6 +311,7 @@ impl ServerBuilder {
args: None,
slots_per_cq: DEFAULT_REQUEST_SLOTS_PER_CQ,
handlers: HashMap::new(),
checkers: Vec::new(),
}
}

Expand Down Expand Up @@ -320,6 +341,16 @@ impl ServerBuilder {
self
}

/// Add a custom checker to handle some tasks before the grpc call handler starts.
/// This allows users to operate grpc call based on the context. Users can add
/// multiple checkers and they will be executed in the order added.
///
/// TODO: Extend this interface to intercepte each payload like grpc-c++.
pub fn add_checker<C: ServerChecker + 'static>(mut self, checker: C) -> ServerBuilder {
self.checkers.push(Box::new(checker));
self
}

/// Finalize the [`ServerBuilder`] and build the [`Server`].
pub fn build(mut self) -> Result<Server> {
let args = self
Expand Down Expand Up @@ -355,6 +386,7 @@ impl ServerBuilder {
slots_per_cq: self.slots_per_cq,
}),
handlers: self.handlers,
checkers: self.checkers,
})
}
}
Expand Down Expand Up @@ -439,6 +471,7 @@ pub type BoxHandler = Box<dyn CloneableHandler>;
pub struct RequestCallContext {
server: Arc<ServerCore>,
registry: Arc<UnsafeCell<HashMap<&'static [u8], BoxHandler>>>,
checkers: Vec<Box<dyn ServerChecker>>,
}

impl RequestCallContext {
Expand All @@ -449,6 +482,10 @@ impl RequestCallContext {
let registry = &mut *self.registry.get();
registry.get_mut(path)
}

pub(crate) fn get_checker(&self) -> Vec<Box<dyn ServerChecker>> {
self.checkers.clone()
}
}

// Apparently, its life time is guaranteed by the ref count, hence is safe to be sent
Expand Down Expand Up @@ -506,6 +543,7 @@ pub struct Server {
env: Arc<Environment>,
core: Arc<ServerCore>,
handlers: HashMap<&'static [u8], BoxHandler>,
checkers: Vec<Box<dyn ServerChecker>>,
}

impl Server {
Expand Down Expand Up @@ -549,6 +587,7 @@ impl Server {
let rc = RequestCallContext {
server: self.core.clone(),
registry: Arc::new(UnsafeCell::new(registry)),
checkers: self.checkers.clone(),
};
for _ in 0..self.core.slots_per_cq {
request_call(rc.clone(), cq);
Expand Down
52 changes: 52 additions & 0 deletions tests-and-examples/tests/cases/misc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -221,3 +221,55 @@ fn test_shutdown_when_exists_grpc_call() {
"Send should get error because server is shutdown, so the grpc is cancelled."
);
}

#[test]
fn test_custom_checker_server_side() {
let flag = Arc::new(atomic::AtomicBool::new(false));
let checker = FlagChecker { flag: flag.clone() };

let env = Arc::new(Environment::new(2));
// Start a server and delay the process of grpc server.
let service = create_greeter(PeerService);
let mut server = ServerBuilder::new(env.clone())
.add_checker(checker)
.register_service(service)
.bind("127.0.0.1", 0)
.build()
.unwrap();
server.start();
let port = server.bind_addrs().next().unwrap().1;
let ch = ChannelBuilder::new(env).connect(&format!("127.0.0.1:{}", port));
let client = GreeterClient::new(ch);
let req = HelloRequest::default();

let _ = client.say_hello(&req).unwrap();
let _ = client.say_hello(&req).unwrap();

flag.store(true, Ordering::SeqCst);
assert_eq!(
client.say_hello(&req).unwrap_err().to_string(),
"RpcFailure: 15-DATA_LOSS ".to_owned()
);
}

#[derive(Clone)]
struct FlagChecker {
flag: Arc<atomic::AtomicBool>,
}

impl ServerChecker for FlagChecker {
fn check(&mut self, ctx: &RpcContext) -> CheckResult {
let method = String::from_utf8(ctx.method().to_owned());
assert_eq!(&method.unwrap(), "/helloworld.Greeter/SayHello");

if self.flag.load(Ordering::SeqCst) {
CheckResult::Abort(RpcStatus::new(RpcStatusCode::DATA_LOSS, None))
} else {
CheckResult::Continue
}
}

fn box_clone(&self) -> Box<dyn ServerChecker> {
Box::new(self.clone())
}
}