Skip to content

Commit 30066a7

Browse files
authored
Add support for gRPC calls. (#100)
Signed-off-by: Rei Shimizu <[email protected]>
1 parent a30f30c commit 30066a7

File tree

4 files changed

+189
-0
lines changed

4 files changed

+189
-0
lines changed

src/dispatcher.rs

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,10 @@ pub(crate) fn register_callout(token_id: u32) {
3838
DISPATCHER.with(|dispatcher| dispatcher.register_callout(token_id));
3939
}
4040

41+
pub(crate) fn register_grpc_callout(token_id: u32) {
42+
DISPATCHER.with(|dispatcher| dispatcher.register_grpc_callout(token_id));
43+
}
44+
4145
struct NoopRoot;
4246

4347
impl Context for NoopRoot {}
@@ -52,6 +56,7 @@ struct Dispatcher {
5256
http_streams: RefCell<HashMap<u32, Box<dyn HttpContext>>>,
5357
active_id: Cell<u32>,
5458
callouts: RefCell<HashMap<u32, u32>>,
59+
grpc_callouts: RefCell<HashMap<u32, u32>>,
5560
}
5661

5762
impl Dispatcher {
@@ -65,6 +70,7 @@ impl Dispatcher {
6570
http_streams: RefCell::new(HashMap::new()),
6671
active_id: Cell::new(0),
6772
callouts: RefCell::new(HashMap::new()),
73+
grpc_callouts: RefCell::new(HashMap::new()),
6874
}
6975
}
7076

@@ -91,6 +97,17 @@ impl Dispatcher {
9197
}
9298
}
9399

100+
fn register_grpc_callout(&self, token_id: u32) {
101+
if self
102+
.grpc_callouts
103+
.borrow_mut()
104+
.insert(token_id, self.active_id.get())
105+
.is_some()
106+
{
107+
panic!("duplicate token_id")
108+
}
109+
}
110+
94111
fn create_root_context(&self, context_id: u32) {
95112
let new_context = match self.new_root.get() {
96113
Some(f) => f(context_id),
@@ -381,6 +398,50 @@ impl Dispatcher {
381398
root.on_http_call_response(token_id, num_headers, body_size, num_trailers)
382399
}
383400
}
401+
402+
fn on_grpc_receive(&self, token_id: u32, response_size: usize) {
403+
let context_id = self
404+
.grpc_callouts
405+
.borrow_mut()
406+
.remove(&token_id)
407+
.expect("invalid token_id");
408+
409+
if let Some(http_stream) = self.http_streams.borrow_mut().get_mut(&context_id) {
410+
self.active_id.set(context_id);
411+
hostcalls::set_effective_context(context_id).unwrap();
412+
http_stream.on_grpc_call_response(token_id, 0, response_size);
413+
} else if let Some(stream) = self.streams.borrow_mut().get_mut(&context_id) {
414+
self.active_id.set(context_id);
415+
hostcalls::set_effective_context(context_id).unwrap();
416+
stream.on_grpc_call_response(token_id, 0, response_size);
417+
} else if let Some(root) = self.roots.borrow_mut().get_mut(&context_id) {
418+
self.active_id.set(context_id);
419+
hostcalls::set_effective_context(context_id).unwrap();
420+
root.on_grpc_call_response(token_id, 0, response_size);
421+
}
422+
}
423+
424+
fn on_grpc_close(&self, token_id: u32, status_code: u32) {
425+
let context_id = self
426+
.grpc_callouts
427+
.borrow_mut()
428+
.remove(&token_id)
429+
.expect("invalid token_id");
430+
431+
if let Some(http_stream) = self.http_streams.borrow_mut().get_mut(&context_id) {
432+
self.active_id.set(context_id);
433+
hostcalls::set_effective_context(context_id).unwrap();
434+
http_stream.on_grpc_call_response(token_id, status_code, 0);
435+
} else if let Some(stream) = self.streams.borrow_mut().get_mut(&context_id) {
436+
self.active_id.set(context_id);
437+
hostcalls::set_effective_context(context_id).unwrap();
438+
stream.on_grpc_call_response(token_id, status_code, 0);
439+
} else if let Some(root) = self.roots.borrow_mut().get_mut(&context_id) {
440+
self.active_id.set(context_id);
441+
hostcalls::set_effective_context(context_id).unwrap();
442+
root.on_grpc_call_response(token_id, status_code, 0);
443+
}
444+
}
384445
}
385446

386447
#[no_mangle]
@@ -509,3 +570,13 @@ pub extern "C" fn proxy_on_http_call_response(
509570
dispatcher.on_http_call_response(token_id, num_headers, body_size, num_trailers)
510571
})
511572
}
573+
574+
#[no_mangle]
575+
pub extern "C" fn proxy_on_grpc_receive(_context_id: u32, token_id: u32, response_size: usize) {
576+
DISPATCHER.with(|dispatcher| dispatcher.on_grpc_receive(token_id, response_size))
577+
}
578+
579+
#[no_mangle]
580+
pub extern "C" fn proxy_on_grpc_close(_context_id: u32, token_id: u32, status_code: u32) {
581+
DISPATCHER.with(|dispatcher| dispatcher.on_grpc_close(token_id, status_code))
582+
}

src/hostcalls.rs

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -651,6 +651,73 @@ pub fn dispatch_http_call(
651651
}
652652
}
653653

654+
extern "C" {
655+
fn proxy_grpc_call(
656+
upstream_data: *const u8,
657+
upstream_size: usize,
658+
service_name_data: *const u8,
659+
service_name_size: usize,
660+
method_name_data: *const u8,
661+
method_name_size: usize,
662+
initial_metadata_data: *const u8,
663+
initial_metadata_size: usize,
664+
message_data_data: *const u8,
665+
message_data_size: usize,
666+
timeout: u32,
667+
return_callout_id: *mut u32,
668+
) -> Status;
669+
}
670+
671+
pub fn dispatch_grpc_call(
672+
upstream_name: &str,
673+
service_name: &str,
674+
method_name: &str,
675+
initial_metadata: Vec<(&str, &[u8])>,
676+
message: Option<&[u8]>,
677+
timeout: Duration,
678+
) -> Result<u32, Status> {
679+
let mut return_callout_id = 0;
680+
let serialized_initial_metadata = utils::serialize_bytes_value_map(initial_metadata);
681+
unsafe {
682+
match proxy_grpc_call(
683+
upstream_name.as_ptr(),
684+
upstream_name.len(),
685+
service_name.as_ptr(),
686+
service_name.len(),
687+
method_name.as_ptr(),
688+
method_name.len(),
689+
serialized_initial_metadata.as_ptr(),
690+
serialized_initial_metadata.len(),
691+
message.map_or(null(), |message| message.as_ptr()),
692+
message.map_or(0, |message| message.len()),
693+
timeout.as_millis() as u32,
694+
&mut return_callout_id,
695+
) {
696+
Status::Ok => {
697+
dispatcher::register_grpc_callout(return_callout_id);
698+
Ok(return_callout_id)
699+
}
700+
Status::ParseFailure => Err(Status::ParseFailure),
701+
Status::InternalFailure => Err(Status::InternalFailure),
702+
status => panic!("unexpected status: {}", status as u32),
703+
}
704+
}
705+
}
706+
707+
extern "C" {
708+
fn proxy_grpc_cancel(token_id: u32) -> Status;
709+
}
710+
711+
pub fn cancel_grpc_call(token_id: u32) -> Result<(), Status> {
712+
unsafe {
713+
match proxy_grpc_cancel(token_id) {
714+
Status::Ok => Ok(()),
715+
Status::NotFound => Err(Status::NotFound),
716+
status => panic!("unexpected status: {}", status as u32),
717+
}
718+
}
719+
}
720+
654721
extern "C" {
655722
fn proxy_set_effective_context(context_id: u32) -> Status;
656723
}
@@ -783,6 +850,26 @@ mod utils {
783850
bytes
784851
}
785852

853+
pub(super) fn serialize_bytes_value_map(map: Vec<(&str, &[u8])>) -> Bytes {
854+
let mut size: usize = 4;
855+
for (name, value) in &map {
856+
size += name.len() + value.len() + 10;
857+
}
858+
let mut bytes: Bytes = Vec::with_capacity(size);
859+
bytes.extend_from_slice(&map.len().to_le_bytes());
860+
for (name, value) in &map {
861+
bytes.extend_from_slice(&name.len().to_le_bytes());
862+
bytes.extend_from_slice(&value.len().to_le_bytes());
863+
}
864+
for (name, value) in &map {
865+
bytes.extend_from_slice(&name.as_bytes());
866+
bytes.push(0);
867+
bytes.extend_from_slice(&value);
868+
bytes.push(0);
869+
}
870+
bytes
871+
}
872+
786873
pub(super) fn deserialize_map(bytes: &[u8]) -> Vec<(String, String)> {
787874
let mut map = Vec::new();
788875
if bytes.is_empty() {

src/traits.rs

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,35 @@ pub trait Context {
9090
hostcalls::get_map(MapType::HttpCallResponseTrailers).unwrap()
9191
}
9292

93+
fn dispatch_grpc_call(
94+
&self,
95+
upstream_name: &str,
96+
service_name: &str,
97+
method_name: &str,
98+
initial_metadata: Vec<(&str, &[u8])>,
99+
message: Option<&[u8]>,
100+
timeout: Duration,
101+
) -> Result<u32, Status> {
102+
hostcalls::dispatch_grpc_call(
103+
upstream_name,
104+
service_name,
105+
method_name,
106+
initial_metadata,
107+
message,
108+
timeout,
109+
)
110+
}
111+
112+
fn on_grpc_call_response(&mut self, _token_id: u32, _status_code: u32, _response_size: usize) {}
113+
114+
fn get_grpc_call_response_body(&self, start: usize, max_size: usize) -> Option<Bytes> {
115+
hostcalls::get_buffer(BufferType::GrpcReceiveBuffer, start, max_size).unwrap()
116+
}
117+
118+
fn cancel_grpc_call(&self, token_id: u32) -> Result<(), Status> {
119+
hostcalls::cancel_grpc_call(token_id)
120+
}
121+
93122
fn on_done(&mut self) -> bool {
94123
true
95124
}

src/types.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ pub enum Status {
4242
Ok = 0,
4343
NotFound = 1,
4444
BadArgument = 2,
45+
ParseFailure = 4,
4546
Empty = 7,
4647
CasMismatch = 8,
4748
InternalFailure = 10,
@@ -62,6 +63,7 @@ pub enum BufferType {
6263
DownstreamData = 2,
6364
UpstreamData = 3,
6465
HttpCallResponseBody = 4,
66+
GrpcReceiveBuffer = 5,
6567
}
6668

6769
#[repr(u32)]

0 commit comments

Comments
 (0)