1// SPDX-License-Identifier: GPL-3.0-or-later
2
3#include "librpc/rpc_server.h"
4
5#include "librpc/internal.h"
6#include "librpc/rpc.h"
7
8#include <libipc/ipc.h>
9#include <mos/types.h>
10#include <pb.h>
11#include <pb_decode.h>
12#include <pb_encode.h>
13
#if defined(__MOS_KERNEL__)
#include <mos/lib/sync/mutex.h>
#include <mos_stdio.h>
#include <mos_stdlib.h>
#include <mos_string.h>
#else
#include <assert.h>
#include <errno.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
// Hosted build: map the MOS assertion helpers onto <assert.h>.
// Both operands are parenthesized so that a `cond` containing low-precedence
// operators (assignment, ?:, ...) is not regrouped by the `&&` that attaches msg.
#define MOS_LIB_ASSERT_X(cond, msg) assert((cond) && (msg))
#define MOS_LIB_ASSERT(cond)        assert(cond)
#endif
28
#ifdef __MOS_KERNEL__
#include "mos/io/io.h"
#include "mos/ipc/ipc_io.h"
#include "mos/tasks/kthread.h"

#include <mos/syscall/decl.h>
// Kernel build: the server code in this file is written against the userspace
// syscall names; each one is mapped directly onto the corresponding in-kernel call.
#define syscall_ipc_create(server_name, max_pending) ipc_create(server_name, max_pending)
// The accepted connection is wrapped in io_ref(); syscall_io_close() below maps to
// io_unref(), which rpc_handle_client calls on the client fd when it finishes.
// NOTE(review): presumably this pairing balances the io reference count — confirm in io.h.
#define syscall_ipc_accept(server_fd) io_ref(ipc_accept(server_fd))
#define syscall_ipc_connect(server_name, smh_size) ipc_connect(server_name, smh_size)
// kthread_create takes (entry, arg, name); reorder to match the hosted start_thread(name, entry, arg).
#define start_thread(name, func, arg) kthread_create(func, arg, name)
#define syscall_io_close(fd) io_unref(fd)
#else
// Userspace build: presumably provides the syscall_* wrappers used below — confirm in the header.
#include <mos/syscall/usermode.h>
#endif
43
44#if !defined(__MOS_KERNEL__)
45// fixup for hosted libc
46#include <pthread.h>
47typedef pthread_mutex_t mutex_t;
48#define memzero(ptr, size) memset(ptr, 0, size)
49#define mutex_acquire(mutex) pthread_mutex_lock(mutex)
50#define mutex_release(mutex) pthread_mutex_unlock(mutex)
51#define mos_warn(fmt, ...) fprintf(stderr, fmt "\n", ##__VA_ARGS__)
52#define MOS_LIB_UNREACHABLE() __builtin_unreachable()
53static void start_thread(const char *name, thread_entry_t entry, void *arg)
54{
55 union
56 {
57 thread_entry_t entry;
58 void *(*func)(void *);
59 } u = { entry }; // to make the compiler happy
60 pthread_t thread;
61 pthread_create(&thread, NULL, u.func, arg);
62 pthread_setname_np(thread, name);
63}
64#endif
65
// Second argument passed to syscall_ipc_create() by rpc_server_create().
#define RPC_SERVER_MAX_PENDING_CALLS 32

// Server state: the listening IPC endpoint plus the registered function table.
typedef struct _rpc_server
{
    const char *server_name;                  // stored as given; not copied or freed in this file
    void *data;                               // user data, see rpc_server_set_data/rpc_server_get_data
    ipcfd_t server_fd;                        // listening endpoint; (ipcfd_t) -1 after rpc_server_close
    size_t functions_count;                   // number of entries in `functions`
    rpc_function_info_t *functions;           // heap copy made by rpc_server_register_functions
    rpc_server_on_connect_t on_connect;       // optional; run when a client worker starts
    rpc_server_on_disconnect_t on_disconnect; // optional; run before the client fd is closed
} rpc_server_t;
78
// Cursor used by rpc_arg_next() to walk request->args_array sequentially.
typedef struct _rpc_args_iter
{
    size_t next_arg_index; // index of the next argument to return
    size_t next_arg_byte;  // byte offset of that argument within args_array
} rpc_args_iter_t;
84
// NOTE(review): not referenced by any code in this section, and rpc_fill_result (named
// in the field comment) is not defined here either — confirm whether this is still used.
struct _rpc_reply_wrapper
{
    rpc_response_t *response; // may be relocated by rpc_fill_result
};
89
// Per-client connection state: allocated by rpc_server_exec, freed by rpc_handle_client.
struct _rpc_context
{
    ipcfd_t client_fd;        // accepted client connection; closed when the worker exits
    rpc_server_t *server;     // server that accepted the connection
    rpc_request_t *request;   // request being dispatched; NULL between requests
    rpc_response_t *response; // reply payload built by rpc_write_result*, or NULL
    rpc_args_iter_t arg_iter; // cursor for rpc_arg_next*; reset for every request
    void *data;               // per-connection user data; swapped atomically by rpc_context_set_data
};
99
100static inline rpc_function_info_t *rpc_server_get_function(rpc_server_t *server, u32 function_id)
101{
102 for (size_t i = 0; i < server->functions_count; i++)
103 if (server->functions[i].function_id == function_id)
104 return &server->functions[i];
105 return NULL;
106}
107
108static void rpc_handle_client(void *arg)
109{
110 rpc_context_t *context = (rpc_context_t *) arg;
111
112 if (context->server->on_connect)
113 context->server->on_connect(context);
114
115 while (true)
116 {
117 ipc_msg_t *const msg = ipc_read_msg(fd: context->client_fd);
118 if (!msg)
119 break;
120
121 if (msg->size < sizeof(rpc_request_t))
122 {
123 mos_warn("failed to read message from client");
124 ipc_msg_destroy(buffer: msg);
125 break;
126 }
127
128 rpc_request_t *request = (rpc_request_t *) msg->data;
129 if (request->magic != RPC_REQUEST_MAGIC)
130 {
131 mos_warn("invalid magic in rpc request: %x", request->magic);
132 ipc_msg_destroy(buffer: msg);
133 break;
134 }
135
136 rpc_function_info_t *function = rpc_server_get_function(server: context->server, function_id: request->function_id);
137 if (!function)
138 {
139 mos_warn("invalid function id in rpc request: %d", request->function_id);
140 ipc_msg_destroy(buffer: msg);
141 break;
142 }
143
144 if (request->args_count > RPC_MAX_ARGS)
145 {
146 mos_warn("too many arguments in rpc request: %d", request->args_count);
147 ipc_msg_destroy(buffer: msg);
148 break;
149 }
150
151 if (request->args_count != function->args_count)
152 {
153 mos_warn("invalid number if arguments in rpc request, expected %d, got %d", function->args_count, request->args_count);
154 ipc_msg_destroy(buffer: msg);
155 break;
156 }
157
158 // check argument types
159 const char *argptr = request->args_array;
160 for (size_t i = 0; i < request->args_count; i++)
161 {
162 const rpc_arg_t *arg = (const rpc_arg_t *) argptr;
163 if (arg->magic != RPC_ARG_MAGIC)
164 {
165 mos_warn("invalid magic in rpc argument: %x", arg->magic);
166 ipc_msg_destroy(buffer: msg);
167 break;
168 }
169 if (arg->argtype != function->args_type[i])
170 {
171 mos_warn("invalid argument type in rpc request, expected %d, got %d", function->args_type[i], arg->argtype);
172 ipc_msg_destroy(buffer: msg);
173 break;
174 }
175 argptr += sizeof(rpc_arg_t) + arg->size;
176 }
177 context->request = request;
178 context->response = NULL;
179 context->arg_iter = (rpc_args_iter_t){ 0 };
180
181 const rpc_result_code_t result = function->func(context);
182
183 if (context->response == NULL)
184 {
185 context->response = malloc(sizeof(rpc_response_t));
186 context->response->magic = RPC_RESPONSE_MAGIC;
187 context->response->call_id = request->call_id;
188 context->response->data_size = 0;
189 }
190
191 context->response->result_code = result;
192
193 const bool written = ipc_write_as_msg(context->client_fd, (const char *) context->response, sizeof(rpc_response_t) + context->response->data_size);
194
195 ipc_msg_destroy(msg);
196 free(context->response);
197 context->response = NULL, context->request = NULL, context->arg_iter = (rpc_args_iter_t){ 0 };
198
199 if (!written)
200 {
201 mos_warn("failed to write reply to client");
202 break;
203 }
204 }
205
206 if (context->server->on_disconnect)
207 context->server->on_disconnect(context);
208
209 syscall_io_close(context->client_fd);
210 free(context);
211}
212
213rpc_server_t *rpc_server_create(const char *server_name, void *data)
214{
215 rpc_server_t *server = malloc(sizeof(rpc_server_t));
216 memzero(s: server, n: sizeof(rpc_server_t));
217 server->server_name = server_name;
218 server->data = data;
219#ifndef __MOS_KERNEL__
220 server->server_fd = -1;
221#endif
222 server->functions_count = 0;
223 server->functions = NULL;
224 server->server_fd = syscall_ipc_create(server_name, RPC_SERVER_MAX_PENDING_CALLS);
225 if (IS_ERR_VALUE(server->server_fd))
226 {
227#if !defined(__MOS_KERNEL__)
228 errno = -server->server_fd;
229#endif
230 free(server);
231 return NULL;
232 }
233 return server;
234}
235
236void rpc_server_set_on_connect(rpc_server_t *server, rpc_server_on_connect_t on_connect)
237{
238 server->on_connect = on_connect;
239}
240
241void rpc_server_set_on_disconnect(rpc_server_t *server, rpc_server_on_disconnect_t on_disconnect)
242{
243 server->on_disconnect = on_disconnect;
244}
245
246void rpc_server_close(rpc_server_t *server)
247{
248 syscall_io_close(server->server_fd);
249 server->server_fd = (ipcfd_t) -1;
250}
251
252void rpc_server_destroy(rpc_server_t *server)
253{
254 if (!IS_ERR_VALUE(server->server_fd))
255 syscall_io_close(server->server_fd);
256 if (server->functions)
257 free(server->functions);
258 free(server);
259}
260
261void rpc_server_set_data(rpc_server_t *server, void *data)
262{
263 server->data = data;
264}
265
266void *rpc_server_get_data(rpc_server_t *server)
267{
268 return server->data;
269}
270
271void rpc_server_exec(rpc_server_t *server)
272{
273 while (true)
274 {
275 const ipcfd_t client_fd = syscall_ipc_accept(server->server_fd);
276
277 if (IS_ERR_VALUE(client_fd))
278 {
279 if ((long) client_fd == -EINTR)
280 continue;
281
282 if ((long) client_fd == -ECONNABORTED)
283 break; // server closed
284
285#if !defined(__MOS_KERNEL__)
286 errno = -client_fd;
287#endif
288 break;
289 }
290
291 rpc_context_t *context = malloc(sizeof(rpc_context_t));
292 memset(s: context, c: 0, n: sizeof(rpc_context_t));
293 context->server = server;
294 context->client_fd = client_fd;
295 start_thread("rpc-worker", rpc_handle_client, context);
296 }
297}
298
299bool rpc_server_register_functions(rpc_server_t *server, const rpc_function_info_t *functions, size_t count)
300{
301 MOS_LIB_ASSERT_X(server->functions == NULL, "cannot register multiple times");
302 server->functions = malloc(sizeof(rpc_function_info_t) * count);
303 memcpy(dest: server->functions, src: functions, n: sizeof(rpc_function_info_t) * count);
304 server->functions_count = count;
305 return true;
306}
307
308void *rpc_context_get_data(const rpc_context_t *context)
309{
310 return context->data;
311}
312
313void *rpc_context_set_data(rpc_context_t *context, void *data)
314{
315 void *old = NULL;
316 __atomic_exchange(&context->data, &data, &old, __ATOMIC_SEQ_CST);
317 return old;
318}
319
320rpc_server_t *rpc_context_get_server(const rpc_context_t *context)
321{
322 return context->server;
323}
324
325MOSAPI int rpc_context_get_function_id(const rpc_context_t *context)
326{
327 if (!context->request)
328 return -1;
329 return context->request->function_id;
330}
331
332const void *rpc_arg_next(rpc_context_t *context, size_t *size)
333{
334 if (context->arg_iter.next_arg_index >= context->request->args_count)
335 return NULL;
336
337 rpc_args_iter_t *const args = &context->arg_iter;
338
339 const size_t next_arg_byte = args->next_arg_byte;
340
341 const rpc_arg_t *arg = (rpc_arg_t *) &context->request->args_array[next_arg_byte];
342 if (arg->magic != RPC_ARG_MAGIC)
343 return NULL;
344
345 args->next_arg_index++;
346 args->next_arg_byte += sizeof(rpc_arg_t) + arg->size;
347
348 if (size)
349 *size = arg->size;
350
351 return arg->data;
352}
353
354const void *rpc_arg_sized_next(rpc_context_t *context, size_t expected_size)
355{
356 size_t size = 0;
357 const void *data = rpc_arg_next(context, size: &size);
358 if (size != expected_size)
359 return NULL;
360 return (void *) data;
361}
362
363#define RPC_ARG_NEXT_IMPL(type, TYPE) \
364 type rpc_arg_next_##type(rpc_context_t *context) \
365 { \
366 return *(type *) rpc_arg_next(context, NULL); \
367 }
368
369RPC_ARG_NEXT_IMPL(u8, UINT8)
370RPC_ARG_NEXT_IMPL(u16, UINT16)
371RPC_ARG_NEXT_IMPL(u32, UINT32)
372RPC_ARG_NEXT_IMPL(u64, UINT64)
373RPC_ARG_NEXT_IMPL(s8, INT8)
374RPC_ARG_NEXT_IMPL(s16, INT16)
375RPC_ARG_NEXT_IMPL(s32, INT32)
376RPC_ARG_NEXT_IMPL(s64, INT64)
377
378const char *rpc_arg_next_string(rpc_context_t *context)
379{
380 return rpc_arg_next(context, NULL);
381}
382
383const void *rpc_arg(const rpc_context_t *context, size_t iarg, rpc_argtype_t type, size_t *argsize)
384{
385 // iterate over arguments
386 const char *ptr = context->request->args_array;
387 for (size_t i = 0; i < iarg; i++)
388 {
389 const rpc_arg_t *arg = (const rpc_arg_t *) ptr;
390 MOS_LIB_ASSERT(arg->magic == RPC_ARG_MAGIC);
391 ptr += sizeof(rpc_arg_t) + arg->size;
392 }
393
394 const rpc_arg_t *arg = (const rpc_arg_t *) ptr;
395 MOS_LIB_ASSERT(arg->magic == RPC_ARG_MAGIC);
396 MOS_LIB_ASSERT(arg->argtype == type);
397 if (argsize)
398 *argsize = arg->size;
399 return arg->data;
400}
401
402#define RPC_GET_ARG_IMPL(type, TYPE) \
403 type rpc_arg_##type(const rpc_context_t *context, size_t iarg) \
404 { \
405 return *(type *) rpc_arg(context, iarg, RPC_ARGTYPE_##TYPE, NULL); \
406 }
407
408RPC_GET_ARG_IMPL(u8, UINT32)
409RPC_GET_ARG_IMPL(u16, UINT32)
410RPC_GET_ARG_IMPL(u32, UINT32)
411RPC_GET_ARG_IMPL(u64, UINT64)
412RPC_GET_ARG_IMPL(s8, INT32)
413RPC_GET_ARG_IMPL(s16, INT32)
414RPC_GET_ARG_IMPL(s32, INT32)
415RPC_GET_ARG_IMPL(s64, INT64)
416
417const char *rpc_arg_string(const rpc_context_t *context, size_t iarg)
418{
419 return (const char *) rpc_arg(context, iarg, type: RPC_ARGTYPE_STRING, NULL);
420}
421
422void rpc_write_result(rpc_context_t *context, const void *data, size_t size)
423{
424 MOS_LIB_ASSERT_X(context->response == NULL, "rpc_write_result called twice");
425
426 rpc_response_t *response = malloc(sizeof(rpc_response_t) + size);
427 response->magic = RPC_RESPONSE_MAGIC;
428 response->call_id = context->request->call_id;
429 response->result_code = RPC_RESULT_OK;
430 response->data_size = size;
431 memcpy(dest: response->data, src: data, n: size);
432 context->response = response;
433}
434
435bool rpc_arg_pb(rpc_context_t *context, const pb_msgdesc_t *fields, void *val, size_t argid)
436{
437 size_t size = 0;
438 const void *payload = rpc_arg(context, iarg: argid, type: RPC_ARGTYPE_BUFFER, argsize: &size);
439 pb_istream_t stream = pb_istream_from_buffer(buf: (const pb_byte_t *) payload, msglen: size);
440 return pb_decode(stream: &stream, fields, dest_struct: val);
441}
442
443void rpc_write_result_pb(rpc_context_t *context, const pb_msgdesc_t *type_fields, const void *val)
444{
445 size_t bufsize;
446 pb_get_encoded_size(size: &bufsize, fields: type_fields, src_struct: val);
447 pb_byte_t buffer[bufsize];
448 pb_ostream_t stream = pb_ostream_from_buffer(buf: buffer, bufsize);
449 const int retval = pb_encode(stream: &stream, fields: type_fields, src_struct: val);
450 if (retval)
451 rpc_write_result(context, data: buffer, size: stream.bytes_written);
452}
453