
Commit b486ba0

rpc : add rpc_msg_set_tensor_hash_req (#13353)
* rpc : add rpc_msg_set_tensor_hash_req

  Use a dedicated struct for the request of RPC_CMD_SET_TENSOR_HASH, which makes the code cleaner.

* fix
1 parent 02115dc commit b486ba0
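
For context, here is a minimal, self-contained sketch (not the actual ggml-rpc code) of the pattern this commit moves to: instead of hand-packing | rpc_tensor | offset (8 bytes) | hash (8 bytes) | into a byte vector, the client fills a fixed-size request struct and the transport exchanges exactly sizeof(request) bytes. The names rpc_tensor_stub, set_tensor_hash_req_stub, and fnv1a_hash below are simplified stand-ins for rpc_tensor, rpc_msg_set_tensor_hash_req, and fnv_hash, which are defined elsewhere in ggml-rpc.cpp.

    #include <cstddef>
    #include <cstdint>
    #include <cstdio>

    // Simplified stand-in for the real rpc_tensor (the actual struct carries
    // many more fields); shown only to illustrate the layout idea.
    struct rpc_tensor_stub {
        uint64_t id;
        uint64_t data;
    };

    // Mirror of the new fixed-size request: tensor descriptor, write offset,
    // and the hash of the payload.
    struct set_tensor_hash_req_stub {
        rpc_tensor_stub tensor;
        uint64_t offset;
        uint64_t hash;
    };

    // FNV-1a over the payload, a stand-in for the hash the client computes
    // before asking the server whether it already has the data cached.
    static uint64_t fnv1a_hash(const uint8_t * data, size_t len) {
        uint64_t h = 0xcbf29ce484222325ULL;
        for (size_t i = 0; i < len; ++i) {
            h ^= data[i];
            h *= 0x100000001b3ULL;
        }
        return h;
    }

    int main() {
        const uint8_t payload[] = {1, 2, 3, 4};

        // Old scheme: memcpy the fields one by one into a std::vector<uint8_t>.
        // New scheme: fill the struct and send sizeof(request) raw bytes.
        set_tensor_hash_req_stub request{};
        request.tensor = rpc_tensor_stub{42, 0x1000};
        request.offset = 0;
        request.hash   = fnv1a_hash(payload, sizeof(payload));

        // The server can receive exactly sizeof(request) bytes and read the
        // fields directly instead of re-parsing a length-checked byte buffer.
        printf("request size: %zu bytes, hash: 0x%016llx\n",
               sizeof(request), (unsigned long long) request.hash);
        return 0;
    }

Sending raw struct bytes like this only works when client and server agree exactly on the byte layout (field order, padding, endianness), which the RPC message structs in ggml-rpc.cpp are arranged to ensure; it is what lets the server drop the manual size check and offset arithmetic from the old parsing code.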

1 file changed (+25 −27)

ggml/src/ggml-rpc/ggml-rpc.cpp

Lines changed: 25 additions & 27 deletions
@@ -151,6 +151,12 @@ struct rpc_msg_buffer_clear_req {
     uint8_t value;
 };
 
+struct rpc_msg_set_tensor_hash_req {
+    rpc_tensor tensor;
+    uint64_t offset;
+    uint64_t hash;
+};
+
 struct rpc_msg_set_tensor_hash_rsp {
     uint8_t result;
 };
@@ -548,15 +554,12 @@ static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t buffer, ggm
     ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
     rpc_tensor rpc_tensor = serialize_tensor(tensor);
     if (size > HASH_THRESHOLD) {
-        // input serialization format: | rpc_tensor | offset (8 bytes) | hash (8 bytes)
-        size_t input_size = sizeof(rpc_tensor) + sizeof(uint64_t) + sizeof(uint64_t);
-        std::vector<uint8_t> input(input_size, 0);
-        uint64_t hash = fnv_hash((const uint8_t*)data, size);
-        memcpy(input.data(), &rpc_tensor, sizeof(rpc_tensor));
-        memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset));
-        memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), &hash, sizeof(hash));
+        rpc_msg_set_tensor_hash_req request;
+        request.tensor = rpc_tensor;
+        request.offset = offset;
+        request.hash = fnv_hash((const uint8_t*)data, size);
         rpc_msg_set_tensor_hash_rsp response;
-        bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR_HASH, input.data(), input.size(), &response, sizeof(response));
+        bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR_HASH, &request, sizeof(request), &response, sizeof(response));
         GGML_ASSERT(status);
         if (response.result) {
             // the server has the same data, no need to send it
@@ -864,7 +867,7 @@ class rpc_server {
     bool free_buffer(const rpc_msg_free_buffer_req & request);
     bool buffer_clear(const rpc_msg_buffer_clear_req & request);
     bool set_tensor(const std::vector<uint8_t> & input);
-    bool set_tensor_hash(const std::vector<uint8_t> & input, rpc_msg_set_tensor_hash_rsp & response);
+    bool set_tensor_hash(const rpc_msg_set_tensor_hash_req & request, rpc_msg_set_tensor_hash_rsp & response);
     bool get_tensor(const rpc_msg_get_tensor_req & request, std::vector<uint8_t> & response);
     bool copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response);
     bool graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response);
@@ -1101,18 +1104,10 @@ bool rpc_server::get_cached_file(uint64_t hash, std::vector<uint8_t> & data) {
     return true;
 }
 
-bool rpc_server::set_tensor_hash(const std::vector<uint8_t> & input, rpc_msg_set_tensor_hash_rsp & response)
+bool rpc_server::set_tensor_hash(const rpc_msg_set_tensor_hash_req & request, rpc_msg_set_tensor_hash_rsp & response)
 {
-    // serialization format: | rpc_tensor | offset (8 bytes) | hash (8 bytes) |
-    if (input.size() != sizeof(rpc_tensor) + 16) {
-        return false;
-    }
-    const rpc_tensor * in_tensor = (const rpc_tensor *)input.data();
-    uint64_t offset;
-    memcpy(&offset, input.data() + sizeof(rpc_tensor), sizeof(offset));
-    const uint64_t * hash = (const uint64_t *)(input.data() + sizeof(rpc_tensor) + sizeof(offset));
     std::vector<uint8_t> cached_file;
-    if (!get_cached_file(*hash, cached_file)) {
+    if (!get_cached_file(request.hash, cached_file)) {
         response.result = 0;
         return true;
     }
@@ -1125,25 +1120,28 @@ bool rpc_server::set_tensor_hash(const std::vector<uint8_t> & input, rpc_msg_set
     ggml_context_ptr ctx_ptr { ggml_init(params) };
     GGML_ASSERT(ctx_ptr != nullptr);
     ggml_context * ctx = ctx_ptr.get();
-    ggml_tensor * tensor = deserialize_tensor(ctx, in_tensor);
+    ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
     if (tensor == nullptr) {
         GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__);
         return false;
     }
-    GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %zu, hash: %" PRIx64 "\n", __func__, (void*)tensor->buffer, tensor->data, offset, size, *hash);
+    GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %zu, hash: %" PRIx64 "\n",
+        __func__, (void*)tensor->buffer, tensor->data, request.offset, size, request.hash);
 
     // sanitize tensor->data
     {
         const size_t p0 = (size_t) ggml_backend_buffer_get_base(tensor->buffer);
         const size_t p1 = p0 + ggml_backend_buffer_get_size(tensor->buffer);
 
-        if (in_tensor->data + offset < p0 || in_tensor->data + offset >= p1 || size > (p1 - in_tensor->data - offset)) {
+        if (request.tensor.data + request.offset < p0
+            || request.tensor.data + request.offset >= p1
+            || size > (p1 - request.tensor.data - request.offset)) {
             GGML_LOG_ERROR("[%s] tensor data region (data=0x%" PRIx64 ", offset=%" PRIu64 ", size=%zu, hash=0x%" PRIx64 ") out of buffer bounds [0x%zx, 0x%zx)\n",
-                __func__, in_tensor->data, offset, size, *hash, p0, p1);
+                __func__, request.tensor.data, request.offset, size, request.hash, p0, p1);
             return false;
         }
     }
-    ggml_backend_tensor_set(tensor, cached_file.data(), offset, size);
+    ggml_backend_tensor_set(tensor, cached_file.data(), request.offset, size);
     response.result = 1;
     return true;
 }
@@ -1503,12 +1501,12 @@ static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
                 break;
             }
            case RPC_CMD_SET_TENSOR_HASH: {
-                std::vector<uint8_t> input;
-                if (!recv_msg(sockfd, input)) {
+                rpc_msg_set_tensor_hash_req request;
+                if (!recv_msg(sockfd, &request, sizeof(request))) {
                     return;
                 }
                 rpc_msg_set_tensor_hash_rsp response;
-                if (!server.set_tensor_hash(input, response)) {
+                if (!server.set_tensor_hash(request, response)) {
                     return;
                 }
                 if (!send_msg(sockfd, &response, sizeof(response))) {
