1 | // SPDX-License-Identifier: GPL-3.0-or-later |
2 | |
3 | #include "librpc/rpc_server.h" |
4 | |
5 | #include "librpc/internal.h" |
6 | #include "librpc/rpc.h" |
7 | |
8 | #include <libipc/ipc.h> |
9 | #include <mos/types.h> |
10 | #include <pb.h> |
11 | #include <pb_decode.h> |
12 | #include <pb_encode.h> |
13 | |
14 | #if defined(__MOS_KERNEL__) |
15 | #include <mos/lib/sync/mutex.h> |
16 | #include <mos_stdio.h> |
17 | #include <mos_stdlib.h> |
18 | #include <mos_string.h> |
19 | #else |
20 | #include <assert.h> |
21 | #include <errno.h> |
22 | #include <stdio.h> |
23 | #include <stdlib.h> |
24 | #include <string.h> |
// Hosted stand-ins for the kernel assertion macros. Both operands are parenthesised so an
// argument containing a low-precedence operator (assignment, ?:) is asserted as a whole.
#define MOS_LIB_ASSERT_X(cond, msg) assert((cond) && (msg))
#define MOS_LIB_ASSERT(cond)        assert(cond)
27 | #endif |
28 | |
29 | #ifdef __MOS_KERNEL__ |
30 | #include "mos/io/io.h" |
31 | #include "mos/ipc/ipc_io.h" |
32 | #include "mos/tasks/kthread.h" |
33 | |
34 | #include <mos/syscall/decl.h> |
35 | #define syscall_ipc_create(server_name, max_pending) ipc_create(server_name, max_pending) |
36 | #define syscall_ipc_accept(server_fd) io_ref(ipc_accept(server_fd)) |
37 | #define syscall_ipc_connect(server_name, smh_size) ipc_connect(server_name, smh_size) |
38 | #define start_thread(name, func, arg) kthread_create(func, arg, name) |
39 | #define syscall_io_close(fd) io_unref(fd) |
40 | #else |
41 | #include <mos/syscall/usermode.h> |
42 | #endif |
43 | |
44 | #if !defined(__MOS_KERNEL__) |
45 | // fixup for hosted libc |
46 | #include <pthread.h> |
47 | typedef pthread_mutex_t mutex_t; |
48 | #define memzero(ptr, size) memset(ptr, 0, size) |
49 | #define mutex_acquire(mutex) pthread_mutex_lock(mutex) |
50 | #define mutex_release(mutex) pthread_mutex_unlock(mutex) |
51 | #define mos_warn(fmt, ...) fprintf(stderr, fmt "\n", ##__VA_ARGS__) |
52 | #define MOS_LIB_UNREACHABLE() __builtin_unreachable() |
53 | static void start_thread(const char *name, thread_entry_t entry, void *arg) |
54 | { |
55 | union |
56 | { |
57 | thread_entry_t entry; |
58 | void *(*func)(void *); |
59 | } u = { entry }; // to make the compiler happy |
60 | pthread_t thread; |
61 | pthread_create(&thread, NULL, u.func, arg); |
62 | pthread_setname_np(thread, name); |
63 | } |
64 | #endif |
65 | |
66 | #define RPC_SERVER_MAX_PENDING_CALLS 32 |
67 | |
/**
 * Server-side state for one named RPC endpoint.
 * Created by rpc_server_create() and released by rpc_server_destroy().
 */
typedef struct _rpc_server
{
    const char *server_name;                  // name given to syscall_ipc_create(); the pointer is stored, not copied
    void *data;                               // opaque user pointer, see rpc_server_set_data() / rpc_server_get_data()
    ipcfd_t server_fd;                        // listening IPC handle; rpc_server_close() sets it to (ipcfd_t) -1
    size_t functions_count;                   // number of entries in `functions`
    rpc_function_info_t *functions;           // heap copy of the table given to rpc_server_register_functions(), NULL before
    rpc_server_on_connect_t on_connect;       // optional; called by the worker before it reads the first request
    rpc_server_on_disconnect_t on_disconnect; // optional; called by the worker before the client handle is closed
} rpc_server_t;
78 | |
/**
 * Cursor used by rpc_arg_next() to walk the packed arguments of the current request.
 * Reset to all-zero before each handler invocation.
 */
typedef struct _rpc_args_iter
{
    size_t next_arg_index; // number of arguments already consumed
    size_t next_arg_byte;  // byte offset into request->args_array of the next rpc_arg_t header
} rpc_args_iter_t;
84 | |
// Holder for a response pointer. Not referenced by the code visible in this file.
// NOTE(review): the original note says the response "may be relocated by rpc_fill_result",
// which is not defined here - confirm whether this struct is still used.
struct _rpc_reply_wrapper
{
    rpc_response_t *response; // may be relocated by rpc_fill_result (see note above)
};
89 | |
/**
 * Per-client connection state, allocated by rpc_server_exec() and freed by the worker
 * thread (rpc_handle_client) when the client disconnects.
 */
struct _rpc_context
{
    ipcfd_t client_fd;        // accepted IPC handle for this client
    rpc_server_t *server;     // owning server
    rpc_request_t *request;   // request currently being handled; points into the received message, NULL between calls
    rpc_response_t *response; // response built by rpc_write_result(), NULL until a handler writes one
    rpc_args_iter_t arg_iter; // cursor for rpc_arg_next()
    void *data;               // opaque per-connection pointer, see rpc_context_set_data() / rpc_context_get_data()
};
99 | |
100 | static inline rpc_function_info_t *rpc_server_get_function(rpc_server_t *server, u32 function_id) |
101 | { |
102 | for (size_t i = 0; i < server->functions_count; i++) |
103 | if (server->functions[i].function_id == function_id) |
104 | return &server->functions[i]; |
105 | return NULL; |
106 | } |
107 | |
108 | static void rpc_handle_client(void *arg) |
109 | { |
110 | rpc_context_t *context = (rpc_context_t *) arg; |
111 | |
112 | if (context->server->on_connect) |
113 | context->server->on_connect(context); |
114 | |
115 | while (true) |
116 | { |
117 | ipc_msg_t *const msg = ipc_read_msg(fd: context->client_fd); |
118 | if (!msg) |
119 | break; |
120 | |
121 | if (msg->size < sizeof(rpc_request_t)) |
122 | { |
123 | mos_warn("failed to read message from client" ); |
124 | ipc_msg_destroy(buffer: msg); |
125 | break; |
126 | } |
127 | |
128 | rpc_request_t *request = (rpc_request_t *) msg->data; |
129 | if (request->magic != RPC_REQUEST_MAGIC) |
130 | { |
131 | mos_warn("invalid magic in rpc request: %x" , request->magic); |
132 | ipc_msg_destroy(buffer: msg); |
133 | break; |
134 | } |
135 | |
136 | rpc_function_info_t *function = rpc_server_get_function(server: context->server, function_id: request->function_id); |
137 | if (!function) |
138 | { |
139 | mos_warn("invalid function id in rpc request: %d" , request->function_id); |
140 | ipc_msg_destroy(buffer: msg); |
141 | break; |
142 | } |
143 | |
144 | if (request->args_count > RPC_MAX_ARGS) |
145 | { |
146 | mos_warn("too many arguments in rpc request: %d" , request->args_count); |
147 | ipc_msg_destroy(buffer: msg); |
148 | break; |
149 | } |
150 | |
151 | if (request->args_count != function->args_count) |
152 | { |
153 | mos_warn("invalid number if arguments in rpc request, expected %d, got %d" , function->args_count, request->args_count); |
154 | ipc_msg_destroy(buffer: msg); |
155 | break; |
156 | } |
157 | |
158 | // check argument types |
159 | const char *argptr = request->args_array; |
160 | for (size_t i = 0; i < request->args_count; i++) |
161 | { |
162 | const rpc_arg_t *arg = (const rpc_arg_t *) argptr; |
163 | if (arg->magic != RPC_ARG_MAGIC) |
164 | { |
165 | mos_warn("invalid magic in rpc argument: %x" , arg->magic); |
166 | ipc_msg_destroy(buffer: msg); |
167 | break; |
168 | } |
169 | if (arg->argtype != function->args_type[i]) |
170 | { |
171 | mos_warn("invalid argument type in rpc request, expected %d, got %d" , function->args_type[i], arg->argtype); |
172 | ipc_msg_destroy(buffer: msg); |
173 | break; |
174 | } |
175 | argptr += sizeof(rpc_arg_t) + arg->size; |
176 | } |
177 | context->request = request; |
178 | context->response = NULL; |
179 | context->arg_iter = (rpc_args_iter_t){ 0 }; |
180 | |
181 | const rpc_result_code_t result = function->func(context); |
182 | |
183 | if (context->response == NULL) |
184 | { |
185 | context->response = malloc(sizeof(rpc_response_t)); |
186 | context->response->magic = RPC_RESPONSE_MAGIC; |
187 | context->response->call_id = request->call_id; |
188 | context->response->data_size = 0; |
189 | } |
190 | |
191 | context->response->result_code = result; |
192 | |
193 | const bool written = ipc_write_as_msg(context->client_fd, (const char *) context->response, sizeof(rpc_response_t) + context->response->data_size); |
194 | |
195 | ipc_msg_destroy(msg); |
196 | free(context->response); |
197 | context->response = NULL, context->request = NULL, context->arg_iter = (rpc_args_iter_t){ 0 }; |
198 | |
199 | if (!written) |
200 | { |
201 | mos_warn("failed to write reply to client" ); |
202 | break; |
203 | } |
204 | } |
205 | |
206 | if (context->server->on_disconnect) |
207 | context->server->on_disconnect(context); |
208 | |
209 | syscall_io_close(context->client_fd); |
210 | free(context); |
211 | } |
212 | |
213 | rpc_server_t *rpc_server_create(const char *server_name, void *data) |
214 | { |
215 | rpc_server_t *server = malloc(sizeof(rpc_server_t)); |
216 | memzero(s: server, n: sizeof(rpc_server_t)); |
217 | server->server_name = server_name; |
218 | server->data = data; |
219 | #ifndef __MOS_KERNEL__ |
220 | server->server_fd = -1; |
221 | #endif |
222 | server->functions_count = 0; |
223 | server->functions = NULL; |
224 | server->server_fd = syscall_ipc_create(server_name, RPC_SERVER_MAX_PENDING_CALLS); |
225 | if (IS_ERR_VALUE(server->server_fd)) |
226 | { |
227 | #if !defined(__MOS_KERNEL__) |
228 | errno = -server->server_fd; |
229 | #endif |
230 | free(server); |
231 | return NULL; |
232 | } |
233 | return server; |
234 | } |
235 | |
236 | void rpc_server_set_on_connect(rpc_server_t *server, rpc_server_on_connect_t on_connect) |
237 | { |
238 | server->on_connect = on_connect; |
239 | } |
240 | |
241 | void rpc_server_set_on_disconnect(rpc_server_t *server, rpc_server_on_disconnect_t on_disconnect) |
242 | { |
243 | server->on_disconnect = on_disconnect; |
244 | } |
245 | |
246 | void rpc_server_close(rpc_server_t *server) |
247 | { |
248 | syscall_io_close(server->server_fd); |
249 | server->server_fd = (ipcfd_t) -1; |
250 | } |
251 | |
252 | void rpc_server_destroy(rpc_server_t *server) |
253 | { |
254 | if (!IS_ERR_VALUE(server->server_fd)) |
255 | syscall_io_close(server->server_fd); |
256 | if (server->functions) |
257 | free(server->functions); |
258 | free(server); |
259 | } |
260 | |
261 | void rpc_server_set_data(rpc_server_t *server, void *data) |
262 | { |
263 | server->data = data; |
264 | } |
265 | |
266 | void *rpc_server_get_data(rpc_server_t *server) |
267 | { |
268 | return server->data; |
269 | } |
270 | |
271 | void rpc_server_exec(rpc_server_t *server) |
272 | { |
273 | while (true) |
274 | { |
275 | const ipcfd_t client_fd = syscall_ipc_accept(server->server_fd); |
276 | |
277 | if (IS_ERR_VALUE(client_fd)) |
278 | { |
279 | if ((long) client_fd == -EINTR) |
280 | continue; |
281 | |
282 | if ((long) client_fd == -ECONNABORTED) |
283 | break; // server closed |
284 | |
285 | #if !defined(__MOS_KERNEL__) |
286 | errno = -client_fd; |
287 | #endif |
288 | break; |
289 | } |
290 | |
291 | rpc_context_t *context = malloc(sizeof(rpc_context_t)); |
292 | memset(s: context, c: 0, n: sizeof(rpc_context_t)); |
293 | context->server = server; |
294 | context->client_fd = client_fd; |
295 | start_thread("rpc-worker" , rpc_handle_client, context); |
296 | } |
297 | } |
298 | |
299 | bool rpc_server_register_functions(rpc_server_t *server, const rpc_function_info_t *functions, size_t count) |
300 | { |
301 | MOS_LIB_ASSERT_X(server->functions == NULL, "cannot register multiple times" ); |
302 | server->functions = malloc(sizeof(rpc_function_info_t) * count); |
303 | memcpy(dest: server->functions, src: functions, n: sizeof(rpc_function_info_t) * count); |
304 | server->functions_count = count; |
305 | return true; |
306 | } |
307 | |
308 | void *rpc_context_get_data(const rpc_context_t *context) |
309 | { |
310 | return context->data; |
311 | } |
312 | |
313 | void *rpc_context_set_data(rpc_context_t *context, void *data) |
314 | { |
315 | void *old = NULL; |
316 | __atomic_exchange(&context->data, &data, &old, __ATOMIC_SEQ_CST); |
317 | return old; |
318 | } |
319 | |
320 | rpc_server_t *rpc_context_get_server(const rpc_context_t *context) |
321 | { |
322 | return context->server; |
323 | } |
324 | |
325 | MOSAPI int rpc_context_get_function_id(const rpc_context_t *context) |
326 | { |
327 | if (!context->request) |
328 | return -1; |
329 | return context->request->function_id; |
330 | } |
331 | |
332 | const void *rpc_arg_next(rpc_context_t *context, size_t *size) |
333 | { |
334 | if (context->arg_iter.next_arg_index >= context->request->args_count) |
335 | return NULL; |
336 | |
337 | rpc_args_iter_t *const args = &context->arg_iter; |
338 | |
339 | const size_t next_arg_byte = args->next_arg_byte; |
340 | |
341 | const rpc_arg_t *arg = (rpc_arg_t *) &context->request->args_array[next_arg_byte]; |
342 | if (arg->magic != RPC_ARG_MAGIC) |
343 | return NULL; |
344 | |
345 | args->next_arg_index++; |
346 | args->next_arg_byte += sizeof(rpc_arg_t) + arg->size; |
347 | |
348 | if (size) |
349 | *size = arg->size; |
350 | |
351 | return arg->data; |
352 | } |
353 | |
354 | const void *rpc_arg_sized_next(rpc_context_t *context, size_t expected_size) |
355 | { |
356 | size_t size = 0; |
357 | const void *data = rpc_arg_next(context, size: &size); |
358 | if (size != expected_size) |
359 | return NULL; |
360 | return (void *) data; |
361 | } |
362 | |
363 | #define RPC_ARG_NEXT_IMPL(type, TYPE) \ |
364 | type rpc_arg_next_##type(rpc_context_t *context) \ |
365 | { \ |
366 | return *(type *) rpc_arg_next(context, NULL); \ |
367 | } |
368 | |
369 | RPC_ARG_NEXT_IMPL(u8, UINT8) |
370 | RPC_ARG_NEXT_IMPL(u16, UINT16) |
371 | RPC_ARG_NEXT_IMPL(u32, UINT32) |
372 | RPC_ARG_NEXT_IMPL(u64, UINT64) |
373 | RPC_ARG_NEXT_IMPL(s8, INT8) |
374 | RPC_ARG_NEXT_IMPL(s16, INT16) |
375 | RPC_ARG_NEXT_IMPL(s32, INT32) |
376 | RPC_ARG_NEXT_IMPL(s64, INT64) |
377 | |
378 | const char *rpc_arg_next_string(rpc_context_t *context) |
379 | { |
380 | return rpc_arg_next(context, NULL); |
381 | } |
382 | |
383 | const void *rpc_arg(const rpc_context_t *context, size_t iarg, rpc_argtype_t type, size_t *argsize) |
384 | { |
385 | // iterate over arguments |
386 | const char *ptr = context->request->args_array; |
387 | for (size_t i = 0; i < iarg; i++) |
388 | { |
389 | const rpc_arg_t *arg = (const rpc_arg_t *) ptr; |
390 | MOS_LIB_ASSERT(arg->magic == RPC_ARG_MAGIC); |
391 | ptr += sizeof(rpc_arg_t) + arg->size; |
392 | } |
393 | |
394 | const rpc_arg_t *arg = (const rpc_arg_t *) ptr; |
395 | MOS_LIB_ASSERT(arg->magic == RPC_ARG_MAGIC); |
396 | MOS_LIB_ASSERT(arg->argtype == type); |
397 | if (argsize) |
398 | *argsize = arg->size; |
399 | return arg->data; |
400 | } |
401 | |
402 | #define RPC_GET_ARG_IMPL(type, TYPE) \ |
403 | type rpc_arg_##type(const rpc_context_t *context, size_t iarg) \ |
404 | { \ |
405 | return *(type *) rpc_arg(context, iarg, RPC_ARGTYPE_##TYPE, NULL); \ |
406 | } |
407 | |
408 | RPC_GET_ARG_IMPL(u8, UINT32) |
409 | RPC_GET_ARG_IMPL(u16, UINT32) |
410 | RPC_GET_ARG_IMPL(u32, UINT32) |
411 | RPC_GET_ARG_IMPL(u64, UINT64) |
412 | RPC_GET_ARG_IMPL(s8, INT32) |
413 | RPC_GET_ARG_IMPL(s16, INT32) |
414 | RPC_GET_ARG_IMPL(s32, INT32) |
415 | RPC_GET_ARG_IMPL(s64, INT64) |
416 | |
417 | const char *rpc_arg_string(const rpc_context_t *context, size_t iarg) |
418 | { |
419 | return (const char *) rpc_arg(context, iarg, type: RPC_ARGTYPE_STRING, NULL); |
420 | } |
421 | |
422 | void rpc_write_result(rpc_context_t *context, const void *data, size_t size) |
423 | { |
424 | MOS_LIB_ASSERT_X(context->response == NULL, "rpc_write_result called twice" ); |
425 | |
426 | rpc_response_t *response = malloc(sizeof(rpc_response_t) + size); |
427 | response->magic = RPC_RESPONSE_MAGIC; |
428 | response->call_id = context->request->call_id; |
429 | response->result_code = RPC_RESULT_OK; |
430 | response->data_size = size; |
431 | memcpy(dest: response->data, src: data, n: size); |
432 | context->response = response; |
433 | } |
434 | |
435 | bool rpc_arg_pb(rpc_context_t *context, const pb_msgdesc_t *fields, void *val, size_t argid) |
436 | { |
437 | size_t size = 0; |
438 | const void *payload = rpc_arg(context, iarg: argid, type: RPC_ARGTYPE_BUFFER, argsize: &size); |
439 | pb_istream_t stream = pb_istream_from_buffer(buf: (const pb_byte_t *) payload, msglen: size); |
440 | return pb_decode(stream: &stream, fields, dest_struct: val); |
441 | } |
442 | |
443 | void rpc_write_result_pb(rpc_context_t *context, const pb_msgdesc_t *type_fields, const void *val) |
444 | { |
445 | size_t bufsize; |
446 | pb_get_encoded_size(size: &bufsize, fields: type_fields, src_struct: val); |
447 | pb_byte_t buffer[bufsize]; |
448 | pb_ostream_t stream = pb_ostream_from_buffer(buf: buffer, bufsize); |
449 | const int retval = pb_encode(stream: &stream, fields: type_fields, src_struct: val); |
450 | if (retval) |
451 | rpc_write_result(context, data: buffer, size: stream.bytes_written); |
452 | } |
453 | |