@@ -151,6 +151,12 @@ struct rpc_msg_buffer_clear_req {
151
151
uint8_t value;
152
152
};
153
153
154
+ struct rpc_msg_set_tensor_hash_req {
155
+ rpc_tensor tensor;
156
+ uint64_t offset;
157
+ uint64_t hash;
158
+ };
159
+
154
160
struct rpc_msg_set_tensor_hash_rsp {
155
161
uint8_t result;
156
162
};
@@ -548,15 +554,12 @@ static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t buffer, ggm
548
554
ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context ;
549
555
rpc_tensor rpc_tensor = serialize_tensor (tensor);
550
556
if (size > HASH_THRESHOLD) {
551
- // input serialization format: | rpc_tensor | offset (8 bytes) | hash (8 bytes)
552
- size_t input_size = sizeof (rpc_tensor) + sizeof (uint64_t ) + sizeof (uint64_t );
553
- std::vector<uint8_t > input (input_size, 0 );
554
- uint64_t hash = fnv_hash ((const uint8_t *)data, size);
555
- memcpy (input.data (), &rpc_tensor, sizeof (rpc_tensor));
556
- memcpy (input.data () + sizeof (rpc_tensor), &offset, sizeof (offset));
557
- memcpy (input.data () + sizeof (rpc_tensor) + sizeof (offset), &hash, sizeof (hash));
557
+ rpc_msg_set_tensor_hash_req request;
558
+ request.tensor = rpc_tensor;
559
+ request.offset = offset;
560
+ request.hash = fnv_hash ((const uint8_t *)data, size);
558
561
rpc_msg_set_tensor_hash_rsp response;
559
- bool status = send_rpc_cmd (ctx->sock , RPC_CMD_SET_TENSOR_HASH, input. data (), input. size ( ), &response, sizeof (response));
562
+ bool status = send_rpc_cmd (ctx->sock , RPC_CMD_SET_TENSOR_HASH, &request, sizeof (request ), &response, sizeof (response));
560
563
GGML_ASSERT (status);
561
564
if (response.result ) {
562
565
// the server has the same data, no need to send it
@@ -864,7 +867,7 @@ class rpc_server {
864
867
bool free_buffer (const rpc_msg_free_buffer_req & request);
865
868
bool buffer_clear (const rpc_msg_buffer_clear_req & request);
866
869
bool set_tensor (const std::vector<uint8_t > & input);
867
- bool set_tensor_hash (const std::vector< uint8_t > & input , rpc_msg_set_tensor_hash_rsp & response);
870
+ bool set_tensor_hash (const rpc_msg_set_tensor_hash_req & request , rpc_msg_set_tensor_hash_rsp & response);
868
871
bool get_tensor (const rpc_msg_get_tensor_req & request, std::vector<uint8_t > & response);
869
872
bool copy_tensor (const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response);
870
873
bool graph_compute (const std::vector<uint8_t > & input, rpc_msg_graph_compute_rsp & response);
@@ -1101,18 +1104,10 @@ bool rpc_server::get_cached_file(uint64_t hash, std::vector<uint8_t> & data) {
1101
1104
return true ;
1102
1105
}
1103
1106
1104
- bool rpc_server::set_tensor_hash (const std::vector< uint8_t > & input , rpc_msg_set_tensor_hash_rsp & response)
1107
+ bool rpc_server::set_tensor_hash (const rpc_msg_set_tensor_hash_req & request , rpc_msg_set_tensor_hash_rsp & response)
1105
1108
{
1106
- // serialization format: | rpc_tensor | offset (8 bytes) | hash (8 bytes) |
1107
- if (input.size () != sizeof (rpc_tensor) + 16 ) {
1108
- return false ;
1109
- }
1110
- const rpc_tensor * in_tensor = (const rpc_tensor *)input.data ();
1111
- uint64_t offset;
1112
- memcpy (&offset, input.data () + sizeof (rpc_tensor), sizeof (offset));
1113
- const uint64_t * hash = (const uint64_t *)(input.data () + sizeof (rpc_tensor) + sizeof (offset));
1114
1109
std::vector<uint8_t > cached_file;
1115
- if (!get_cached_file (* hash, cached_file)) {
1110
+ if (!get_cached_file (request. hash , cached_file)) {
1116
1111
response.result = 0 ;
1117
1112
return true ;
1118
1113
}
@@ -1125,25 +1120,28 @@ bool rpc_server::set_tensor_hash(const std::vector<uint8_t> & input, rpc_msg_set
1125
1120
ggml_context_ptr ctx_ptr { ggml_init (params) };
1126
1121
GGML_ASSERT (ctx_ptr != nullptr );
1127
1122
ggml_context * ctx = ctx_ptr.get ();
1128
- ggml_tensor * tensor = deserialize_tensor (ctx, in_tensor );
1123
+ ggml_tensor * tensor = deserialize_tensor (ctx, &request. tensor );
1129
1124
if (tensor == nullptr ) {
1130
1125
GGML_LOG_ERROR (" [%s] error deserializing tensor\n " , __func__);
1131
1126
return false ;
1132
1127
}
1133
- GGML_PRINT_DEBUG (" [%s] buffer: %p, data: %p, offset: %" PRIu64 " , size: %zu, hash: %" PRIx64 " \n " , __func__, (void *)tensor->buffer , tensor->data , offset, size, *hash);
1128
+ GGML_PRINT_DEBUG (" [%s] buffer: %p, data: %p, offset: %" PRIu64 " , size: %zu, hash: %" PRIx64 " \n " ,
1129
+ __func__, (void *)tensor->buffer , tensor->data , request.offset , size, request.hash );
1134
1130
1135
1131
// sanitize tensor->data
1136
1132
{
1137
1133
const size_t p0 = (size_t ) ggml_backend_buffer_get_base (tensor->buffer );
1138
1134
const size_t p1 = p0 + ggml_backend_buffer_get_size (tensor->buffer );
1139
1135
1140
- if (in_tensor->data + offset < p0 || in_tensor->data + offset >= p1 || size > (p1 - in_tensor->data - offset)) {
1136
+ if (request.tensor .data + request.offset < p0
1137
+ || request.tensor .data + request.offset >= p1
1138
+ || size > (p1 - request.tensor .data - request.offset )) {
1141
1139
GGML_LOG_ERROR (" [%s] tensor data region (data=0x%" PRIx64 " , offset=%" PRIu64 " , size=%zu, hash=0x%" PRIx64 " ) out of buffer bounds [0x%zx, 0x%zx)\n " ,
1142
- __func__, in_tensor-> data , offset, size, * hash, p0, p1);
1140
+ __func__, request. tensor . data , request. offset , size, request. hash , p0, p1);
1143
1141
return false ;
1144
1142
}
1145
1143
}
1146
- ggml_backend_tensor_set (tensor, cached_file.data (), offset, size);
1144
+ ggml_backend_tensor_set (tensor, cached_file.data (), request. offset , size);
1147
1145
response.result = 1 ;
1148
1146
return true ;
1149
1147
}
@@ -1503,12 +1501,12 @@ static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
1503
1501
break ;
1504
1502
}
1505
1503
case RPC_CMD_SET_TENSOR_HASH: {
1506
- std::vector< uint8_t > input ;
1507
- if (!recv_msg (sockfd, input )) {
1504
+ rpc_msg_set_tensor_hash_req request ;
1505
+ if (!recv_msg (sockfd, &request, sizeof (request) )) {
1508
1506
return ;
1509
1507
}
1510
1508
rpc_msg_set_tensor_hash_rsp response;
1511
- if (!server.set_tensor_hash (input , response)) {
1509
+ if (!server.set_tensor_hash (request , response)) {
1512
1510
return ;
1513
1511
}
1514
1512
if (!send_msg (sockfd, &response, sizeof (response))) {
0 commit comments