| 1 | // SPDX-License-Identifier: GPL-3.0-or-later |
| 2 | |
| 3 | #include "librpc/rpc_server.h" |
| 4 | |
| 5 | #include "librpc/internal.h" |
| 6 | #include "librpc/rpc.h" |
| 7 | |
| 8 | #include <libipc/ipc.h> |
| 9 | #include <mos/types.h> |
| 10 | #include <pb.h> |
| 11 | #include <pb_decode.h> |
| 12 | #include <pb_encode.h> |
| 13 | |
| 14 | #if defined(__MOS_KERNEL__) |
| 15 | #include <mos/lib/sync/mutex.hpp> |
| 16 | #include <mos_stdio.hpp> |
| 17 | #include <mos_stdlib.hpp> |
| 18 | #include <mos_string.hpp> |
| 19 | #else |
| 20 | #include <assert.h> |
| 21 | #include <errno.h> |
| 22 | #include <stdio.h> |
| 23 | #include <stdlib.h> |
| 24 | #include <string.h> |
| 25 | #define MOS_LIB_ASSERT_X(cond, msg) assert(cond &&msg) |
| 26 | #define MOS_LIB_ASSERT(cond) assert(cond) |
| 27 | #endif |
| 28 | |
| 29 | #ifdef __MOS_KERNEL__ |
| 30 | #include "mos/io/io.hpp" |
| 31 | #include "mos/ipc/ipc_io.hpp" |
| 32 | #include "mos/tasks/kthread.hpp" |
| 33 | |
| 34 | #include <mos/syscall/decl.h> |
| 35 | #define syscall_ipc_create(server_name, max_pending) ipc_create(server_name, max_pending).get() |
| 36 | #define syscall_ipc_accept(server_fd) ipc_accept(server_fd)->ref() |
| 37 | #define syscall_ipc_connect(server_name, smh_size) ipc_connect(server_name, smh_size) |
| 38 | #define start_thread(name, func, arg) kthread_create(func, arg, name) |
| 39 | #define syscall_io_close(fd) fd->unref() |
| 40 | #else |
| 41 | #include <mos/syscall/usermode.h> |
| 42 | #endif |
| 43 | |
| 44 | #if !defined(__MOS_KERNEL__) |
| 45 | // fixup for hosted libc |
| 46 | #include <pthread.h> |
| 47 | typedef pthread_mutex_t mutex_t; |
| 48 | #define memzero(ptr, size) memset(ptr, 0, size) |
| 49 | #define mutex_acquire(mutex) pthread_mutex_lock(mutex) |
| 50 | #define mutex_release(mutex) pthread_mutex_unlock(mutex) |
| 51 | #define mos_warn(fmt, ...) fprintf(stderr, fmt "\n", ##__VA_ARGS__) |
| 52 | #define MOS_LIB_UNREACHABLE() __builtin_unreachable() |
| 53 | static void start_thread(const char *name, thread_entry_t entry, void *arg) |
| 54 | { |
| 55 | union |
| 56 | { |
| 57 | thread_entry_t entry; |
| 58 | void *(*func)(void *); |
| 59 | } u = { entry }; // to make the compiler happy |
| 60 | pthread_t thread; |
| 61 | pthread_create(&thread, NULL, u.func, arg); |
| 62 | pthread_setname_np(thread, name); |
| 63 | } |
| 64 | #endif |
| 65 | |
| 66 | #define RPC_SERVER_MAX_PENDING_CALLS 32 |
| 67 | |
| 68 | typedef struct _rpc_server |
| 69 | { |
| 70 | const char *server_name; |
| 71 | void *data; |
| 72 | ipcfd_t server_fd; |
| 73 | size_t functions_count; |
| 74 | rpc_function_info_t *functions; |
| 75 | rpc_server_on_connect_t on_connect; |
| 76 | rpc_server_on_disconnect_t on_disconnect; |
| 77 | } rpc_server_t; |
| 78 | |
/**
 * Cursor used by rpc_arg_next() to walk the packed argument list of a request.
 */
typedef struct _rpc_args_iter
{
    size_t next_arg_index; // index of the argument that will be returned next
    size_t next_arg_byte;  // byte offset of that argument inside the request's args_array
} rpc_args_iter_t;
| 84 | |
| 85 | struct _rpc_reply_wrapper |
| 86 | { |
| 87 | rpc_response_t *response; // may be relocated by rpc_fill_result |
| 88 | }; |
| 89 | |
| 90 | struct _rpc_context |
| 91 | { |
| 92 | ipcfd_t client_fd; |
| 93 | rpc_server_t *server; |
| 94 | rpc_request_t *request; |
| 95 | rpc_response_t *response; |
| 96 | rpc_args_iter_t arg_iter; |
| 97 | void *data; |
| 98 | }; |
| 99 | |
| 100 | static inline rpc_function_info_t *rpc_server_get_function(rpc_server_t *server, u32 function_id) |
| 101 | { |
| 102 | for (size_t i = 0; i < server->functions_count; i++) |
| 103 | if (server->functions[i].function_id == function_id) |
| 104 | return &server->functions[i]; |
| 105 | return NULL; |
| 106 | } |
| 107 | |
| 108 | static void rpc_handle_client(void *arg) |
| 109 | { |
| 110 | rpc_context_t *context = (rpc_context_t *) arg; |
| 111 | |
| 112 | if (context->server->on_connect) |
| 113 | context->server->on_connect(context); |
| 114 | |
| 115 | while (true) |
| 116 | { |
| 117 | ipc_msg_t *const msg = ipc_read_msg(fd: context->client_fd); |
| 118 | if (!msg) |
| 119 | break; |
| 120 | |
| 121 | if (msg->size < sizeof(rpc_request_t)) |
| 122 | { |
| 123 | mos_warn("failed to read message from client" ); |
| 124 | ipc_msg_destroy(buffer: msg); |
| 125 | break; |
| 126 | } |
| 127 | |
| 128 | rpc_request_t *request = (rpc_request_t *) msg->data; |
| 129 | if (request->magic != RPC_REQUEST_MAGIC) |
| 130 | { |
| 131 | mos_warn("invalid magic in rpc request: %x" , request->magic); |
| 132 | ipc_msg_destroy(buffer: msg); |
| 133 | break; |
| 134 | } |
| 135 | |
| 136 | rpc_function_info_t *function = rpc_server_get_function(server: context->server, function_id: request->function_id); |
| 137 | if (!function) |
| 138 | { |
| 139 | mos_warn("invalid function id in rpc request: %d" , request->function_id); |
| 140 | ipc_msg_destroy(buffer: msg); |
| 141 | break; |
| 142 | } |
| 143 | |
| 144 | if (request->args_count > RPC_MAX_ARGS) |
| 145 | { |
| 146 | mos_warn("too many arguments in rpc request: %d" , request->args_count); |
| 147 | ipc_msg_destroy(buffer: msg); |
| 148 | break; |
| 149 | } |
| 150 | |
| 151 | if (request->args_count != function->args_count) |
| 152 | { |
| 153 | mos_warn("invalid number if arguments in rpc request, expected %d, got %d" , function->args_count, request->args_count); |
| 154 | ipc_msg_destroy(buffer: msg); |
| 155 | break; |
| 156 | } |
| 157 | |
| 158 | // check argument types |
| 159 | const char *argptr = request->args_array; |
| 160 | for (size_t i = 0; i < request->args_count; i++) |
| 161 | { |
| 162 | const rpc_arg_t *arg = (const rpc_arg_t *) argptr; |
| 163 | if (arg->magic != RPC_ARG_MAGIC) |
| 164 | { |
| 165 | mos_warn("invalid magic in rpc argument: %x" , arg->magic); |
| 166 | ipc_msg_destroy(buffer: msg); |
| 167 | break; |
| 168 | } |
| 169 | if (arg->argtype != function->args_type[i]) |
| 170 | { |
| 171 | mos_warn("invalid argument type in rpc request, expected %d, got %d" , function->args_type[i], arg->argtype); |
| 172 | ipc_msg_destroy(buffer: msg); |
| 173 | break; |
| 174 | } |
| 175 | argptr += sizeof(rpc_arg_t) + arg->size; |
| 176 | } |
| 177 | context->request = request; |
| 178 | context->response = NULL; |
| 179 | context->arg_iter = (rpc_args_iter_t) { .next_arg_index: 0 }; |
| 180 | |
| 181 | const rpc_result_code_t result = function->func(context); |
| 182 | |
| 183 | if (context->response == NULL) |
| 184 | { |
| 185 | context->response = (rpc_response_t *) malloc(size: sizeof(rpc_response_t)); |
| 186 | context->response->magic = RPC_RESPONSE_MAGIC; |
| 187 | context->response->call_id = request->call_id; |
| 188 | context->response->data_size = 0; |
| 189 | } |
| 190 | |
| 191 | context->response->result_code = result; |
| 192 | |
| 193 | const bool written = ipc_write_as_msg(fd: context->client_fd, data: (const char *) context->response, size: sizeof(rpc_response_t) + context->response->data_size); |
| 194 | |
| 195 | ipc_msg_destroy(buffer: msg); |
| 196 | free(ptr: context->response); |
| 197 | context->response = NULL, context->request = NULL, context->arg_iter = (rpc_args_iter_t) { .next_arg_index: 0 }; |
| 198 | |
| 199 | if (!written) |
| 200 | { |
| 201 | mos_warn("failed to write reply to client" ); |
| 202 | break; |
| 203 | } |
| 204 | } |
| 205 | |
| 206 | if (context->server->on_disconnect) |
| 207 | context->server->on_disconnect(context); |
| 208 | |
| 209 | syscall_io_close(context->client_fd); |
| 210 | free(ptr: context); |
| 211 | } |
| 212 | |
| 213 | rpc_server_t *rpc_server_create(const char *server_name, void *data) |
| 214 | { |
| 215 | rpc_server_t *server = (rpc_server_t *) malloc(size: sizeof(rpc_server_t)); |
| 216 | memzero(s: server, n: sizeof(rpc_server_t)); |
| 217 | server->server_name = server_name; |
| 218 | server->data = data; |
| 219 | #ifndef __MOS_KERNEL__ |
| 220 | server->server_fd = -1; |
| 221 | #endif |
| 222 | server->functions_count = 0; |
| 223 | server->functions = NULL; |
| 224 | server->server_fd = syscall_ipc_create(server_name, RPC_SERVER_MAX_PENDING_CALLS); |
| 225 | if (IS_ERR_VALUE(server->server_fd)) |
| 226 | { |
| 227 | #if !defined(__MOS_KERNEL__) |
| 228 | errno = -server->server_fd; |
| 229 | #endif |
| 230 | free(ptr: server); |
| 231 | return NULL; |
| 232 | } |
| 233 | return server; |
| 234 | } |
| 235 | |
| 236 | void rpc_server_set_on_connect(rpc_server_t *server, rpc_server_on_connect_t on_connect) |
| 237 | { |
| 238 | server->on_connect = on_connect; |
| 239 | } |
| 240 | |
| 241 | void rpc_server_set_on_disconnect(rpc_server_t *server, rpc_server_on_disconnect_t on_disconnect) |
| 242 | { |
| 243 | server->on_disconnect = on_disconnect; |
| 244 | } |
| 245 | |
| 246 | void rpc_server_close(rpc_server_t *server) |
| 247 | { |
| 248 | syscall_io_close(server->server_fd); |
| 249 | server->server_fd = (ipcfd_t) -1; |
| 250 | } |
| 251 | |
| 252 | void rpc_server_destroy(rpc_server_t *server) |
| 253 | { |
| 254 | if (!IS_ERR_VALUE(server->server_fd)) |
| 255 | syscall_io_close(server->server_fd); |
| 256 | if (server->functions) |
| 257 | free(ptr: server->functions); |
| 258 | free(ptr: server); |
| 259 | } |
| 260 | |
| 261 | void rpc_server_set_data(rpc_server_t *server, void *data) |
| 262 | { |
| 263 | server->data = data; |
| 264 | } |
| 265 | |
| 266 | void *rpc_server_get_data(rpc_server_t *server) |
| 267 | { |
| 268 | return server->data; |
| 269 | } |
| 270 | |
| 271 | void rpc_server_exec(rpc_server_t *server) |
| 272 | { |
| 273 | while (true) |
| 274 | { |
| 275 | const ipcfd_t client_fd = syscall_ipc_accept(server->server_fd); |
| 276 | |
| 277 | if (IS_ERR_VALUE(client_fd)) |
| 278 | { |
| 279 | if ((long) client_fd == -EINTR) |
| 280 | continue; |
| 281 | |
| 282 | if ((long) client_fd == -ECONNABORTED) |
| 283 | break; // server closed |
| 284 | |
| 285 | #if !defined(__MOS_KERNEL__) |
| 286 | errno = -client_fd; |
| 287 | #endif |
| 288 | break; |
| 289 | } |
| 290 | |
| 291 | rpc_context_t *context = (rpc_context_t *) malloc(size: sizeof(rpc_context_t)); |
| 292 | memset(s: context, c: 0, n: sizeof(rpc_context_t)); |
| 293 | context->server = server; |
| 294 | context->client_fd = client_fd; |
| 295 | start_thread("rpc-worker" , rpc_handle_client, context); |
| 296 | } |
| 297 | } |
| 298 | |
| 299 | bool rpc_server_register_functions(rpc_server_t *server, const rpc_function_info_t *functions, size_t count) |
| 300 | { |
| 301 | MOS_LIB_ASSERT_X(server->functions == NULL, "cannot register multiple times" ); |
| 302 | server->functions = (rpc_function_info_t *) malloc(size: sizeof(rpc_function_info_t) * count); |
| 303 | memcpy(dest: server->functions, src: functions, n: sizeof(rpc_function_info_t) * count); |
| 304 | server->functions_count = count; |
| 305 | return true; |
| 306 | } |
| 307 | |
| 308 | void *rpc_context_get_data(const rpc_context_t *context) |
| 309 | { |
| 310 | return context->data; |
| 311 | } |
| 312 | |
| 313 | void *rpc_context_set_data(rpc_context_t *context, void *data) |
| 314 | { |
| 315 | void *old = NULL; |
| 316 | __atomic_exchange(&context->data, &data, &old, __ATOMIC_SEQ_CST); |
| 317 | return old; |
| 318 | } |
| 319 | |
| 320 | rpc_server_t *rpc_context_get_server(const rpc_context_t *context) |
| 321 | { |
| 322 | return context->server; |
| 323 | } |
| 324 | |
| 325 | MOSAPI int rpc_context_get_function_id(const rpc_context_t *context) |
| 326 | { |
| 327 | if (!context->request) |
| 328 | return -1; |
| 329 | return context->request->function_id; |
| 330 | } |
| 331 | |
| 332 | const void *rpc_arg_next(rpc_context_t *context, size_t *size) |
| 333 | { |
| 334 | if (context->arg_iter.next_arg_index >= context->request->args_count) |
| 335 | return NULL; |
| 336 | |
| 337 | rpc_args_iter_t *const args = &context->arg_iter; |
| 338 | |
| 339 | const size_t next_arg_byte = args->next_arg_byte; |
| 340 | |
| 341 | const rpc_arg_t *arg = (rpc_arg_t *) &context->request->args_array[next_arg_byte]; |
| 342 | if (arg->magic != RPC_ARG_MAGIC) |
| 343 | return NULL; |
| 344 | |
| 345 | args->next_arg_index++; |
| 346 | args->next_arg_byte += sizeof(rpc_arg_t) + arg->size; |
| 347 | |
| 348 | if (size) |
| 349 | *size = arg->size; |
| 350 | |
| 351 | return arg->data; |
| 352 | } |
| 353 | |
| 354 | const void *rpc_arg_sized_next(rpc_context_t *context, size_t expected_size) |
| 355 | { |
| 356 | size_t size = 0; |
| 357 | const void *data = rpc_arg_next(context, size: &size); |
| 358 | if (size != expected_size) |
| 359 | return NULL; |
| 360 | return (void *) data; |
| 361 | } |
| 362 | |
| 363 | #define RPC_ARG_NEXT_IMPL(type, TYPE) \ |
| 364 | type rpc_arg_next_##type(rpc_context_t *context) \ |
| 365 | { \ |
| 366 | return *(type *) rpc_arg_next(context, NULL); \ |
| 367 | } |
| 368 | |
| 369 | RPC_ARG_NEXT_IMPL(u8, UINT8) |
| 370 | RPC_ARG_NEXT_IMPL(u16, UINT16) |
| 371 | RPC_ARG_NEXT_IMPL(u32, UINT32) |
| 372 | RPC_ARG_NEXT_IMPL(u64, UINT64) |
| 373 | RPC_ARG_NEXT_IMPL(s8, INT8) |
| 374 | RPC_ARG_NEXT_IMPL(s16, INT16) |
| 375 | RPC_ARG_NEXT_IMPL(s32, INT32) |
| 376 | RPC_ARG_NEXT_IMPL(s64, INT64) |
| 377 | |
| 378 | const char *rpc_arg_next_string(rpc_context_t *context) |
| 379 | { |
| 380 | return (const char *) rpc_arg_next(context, NULL); |
| 381 | } |
| 382 | |
| 383 | const void *rpc_arg(const rpc_context_t *context, size_t iarg, rpc_argtype_t type, size_t *argsize) |
| 384 | { |
| 385 | // iterate over arguments |
| 386 | const char *ptr = context->request->args_array; |
| 387 | for (size_t i = 0; i < iarg; i++) |
| 388 | { |
| 389 | const rpc_arg_t *arg = (const rpc_arg_t *) ptr; |
| 390 | MOS_LIB_ASSERT(arg->magic == RPC_ARG_MAGIC); |
| 391 | ptr += sizeof(rpc_arg_t) + arg->size; |
| 392 | } |
| 393 | |
| 394 | const rpc_arg_t *arg = (const rpc_arg_t *) ptr; |
| 395 | MOS_LIB_ASSERT(arg->magic == RPC_ARG_MAGIC); |
| 396 | MOS_LIB_ASSERT(arg->argtype == type); |
| 397 | if (argsize) |
| 398 | *argsize = arg->size; |
| 399 | return arg->data; |
| 400 | } |
| 401 | |
| 402 | #define RPC_GET_ARG_IMPL(type, TYPE) \ |
| 403 | type rpc_arg_##type(const rpc_context_t *context, size_t iarg) \ |
| 404 | { \ |
| 405 | return *(type *) rpc_arg(context, iarg, RPC_ARGTYPE_##TYPE, NULL); \ |
| 406 | } |
| 407 | |
| 408 | RPC_GET_ARG_IMPL(u8, UINT32) |
| 409 | RPC_GET_ARG_IMPL(u16, UINT32) |
| 410 | RPC_GET_ARG_IMPL(u32, UINT32) |
| 411 | RPC_GET_ARG_IMPL(u64, UINT64) |
| 412 | RPC_GET_ARG_IMPL(s8, INT32) |
| 413 | RPC_GET_ARG_IMPL(s16, INT32) |
| 414 | RPC_GET_ARG_IMPL(s32, INT32) |
| 415 | RPC_GET_ARG_IMPL(s64, INT64) |
| 416 | |
| 417 | const char *rpc_arg_string(const rpc_context_t *context, size_t iarg) |
| 418 | { |
| 419 | return (const char *) rpc_arg(context, iarg, type: RPC_ARGTYPE_STRING, NULL); |
| 420 | } |
| 421 | |
| 422 | void rpc_write_result(rpc_context_t *context, const void *data, size_t size) |
| 423 | { |
| 424 | MOS_LIB_ASSERT_X(context->response == NULL, "rpc_write_result called twice" ); |
| 425 | |
| 426 | rpc_response_t *response = (rpc_response_t *) malloc(size: sizeof(rpc_response_t) + size); |
| 427 | response->magic = RPC_RESPONSE_MAGIC; |
| 428 | response->call_id = context->request->call_id; |
| 429 | response->result_code = RPC_RESULT_OK; |
| 430 | response->data_size = size; |
| 431 | memcpy(dest: response->data, src: data, n: size); |
| 432 | context->response = response; |
| 433 | } |
| 434 | |
| 435 | bool rpc_arg_pb(rpc_context_t *context, const pb_msgdesc_t *fields, void *val, size_t argid) |
| 436 | { |
| 437 | size_t size = 0; |
| 438 | const void *payload = rpc_arg(context, iarg: argid, type: RPC_ARGTYPE_BUFFER, argsize: &size); |
| 439 | pb_istream_t stream = pb_istream_from_buffer(buf: (const pb_byte_t *) payload, msglen: size); |
| 440 | return pb_decode(stream: &stream, fields, dest_struct: val); |
| 441 | } |
| 442 | |
| 443 | void rpc_write_result_pb(rpc_context_t *context, const pb_msgdesc_t *type_fields, const void *val) |
| 444 | { |
| 445 | size_t bufsize; |
| 446 | pb_get_encoded_size(size: &bufsize, fields: type_fields, src_struct: val); |
| 447 | pb_byte_t buffer[bufsize]; |
| 448 | pb_ostream_t stream = pb_ostream_from_buffer(buf: buffer, bufsize); |
| 449 | const int retval = pb_encode(stream: &stream, fields: type_fields, src_struct: val); |
| 450 | if (retval) |
| 451 | rpc_write_result(context, data: buffer, size: stream.bytes_written); |
| 452 | } |
| 453 | |