1// SPDX-License-Identifier: GPL-3.0-or-later
2
3#include "librpc/rpc_client.h"
4
5#include "librpc/internal.h"
6#include "librpc/rpc.h"
7
8#include <libipc/ipc.h>
9#include <mos/types.h>
10#include <pb_decode.h>
11#include <pb_encode.h>
12#include <stdarg.h>
13
14#if defined(__MOS_KERNEL__)
15#include <mos/lib/sync/mutex.h>
16#include <mos_stdio.h>
17#include <mos_stdlib.h>
18#include <mos_string.h>
19#else
20#include <stdio.h>
21#include <stdlib.h>
22#include <string.h>
23#endif
24
25#ifdef __MOS_KERNEL__
26#include "mos/ipc/ipc_io.h"
27
28#include <mos/platform/platform.h>
29#include <mos/syscall/decl.h>
30#define syscall_ipc_connect(n, s) ipc_connect(n, s)
31#define syscall_io_close(fd) io_unref(fd)
32#else
33#include <mos/syscall/usermode.h>
34#endif
35
#ifndef __MOS_KERNEL__
// Hosted (non-kernel) build: provide the kernel-style helper names used in
// this file on top of libc and pthreads.
#include <pthread.h>
typedef pthread_mutex_t mutex_t;
#define memzero(ptr, size) memset(ptr, 0, size)
#define mutex_acquire(mutex) pthread_mutex_lock(mutex)
#define mutex_release(mutex) pthread_mutex_unlock(mutex)
// note: writes the formatted text to stderr exactly as given (no newline is appended)
#define mos_warn(...) fprintf(stderr, __VA_ARGS__)
#define MOS_LIB_UNREACHABLE() __builtin_unreachable()
#endif
46
// Size passed to syscall_ipc_connect() when a client stub connects to its server
// (presumably the size of the shared IPC buffer -- confirm against the IPC API).
#define RPC_CLIENT_SMH_SIZE (MOS_PAGE_SIZE)
48
49typedef struct rpc_server_stub
50{
51 const char *server_name;
52 ipcfd_t fd;
53 mutex_t mutex; // only one call at a time
54 atomic_t callid;
55} rpc_server_stub_t;
56
57typedef struct rpc_call
58{
59 rpc_server_stub_t *server;
60 rpc_request_t *request;
61 size_t size;
62 mutex_t mutex;
63} rpc_call_t;
64
65rpc_server_stub_t *rpc_client_create(const char *server_name)
66{
67 rpc_server_stub_t *client = malloc(sizeof(rpc_server_stub_t));
68 memzero(s: client, n: sizeof(rpc_server_stub_t));
69 client->server_name = server_name;
70 client->fd = syscall_ipc_connect(server_name, RPC_CLIENT_SMH_SIZE);
71
72 if (IS_ERR_VALUE(client->fd))
73 {
74 free(client);
75 return NULL;
76 }
77
78 return client;
79}
80
81void rpc_client_destroy(rpc_server_stub_t *server)
82{
83 mutex_acquire(mutex: &server->mutex);
84 syscall_io_close(server->fd);
85 free(server);
86}
87
88rpc_call_t *rpc_call_create(rpc_server_stub_t *server, u32 function_id)
89{
90 rpc_call_t *call = malloc(sizeof(rpc_call_t));
91 memzero(s: call, n: sizeof(rpc_call_t));
92
93 call->request = malloc(sizeof(rpc_request_t));
94 memzero(s: call->request, n: sizeof(rpc_request_t));
95 call->request->magic = RPC_REQUEST_MAGIC;
96 call->request->function_id = function_id;
97 call->size = sizeof(rpc_request_t);
98 call->server = server;
99
100 return call;
101}
102
103void rpc_call_destroy(rpc_call_t *call)
104{
105 mutex_acquire(mutex: &call->mutex);
106 free(call->request);
107 free(call);
108}
109
110void rpc_call_arg(rpc_call_t *call, rpc_argtype_t argtype, const void *data, size_t size)
111{
112 mutex_acquire(mutex: &call->mutex);
113 call->request = realloc(call->request, call->size + sizeof(rpc_arg_t) + size);
114 call->request->args_count += 1;
115
116 rpc_arg_t *arg = (rpc_arg_t *) &call->request->args_array[call->size - sizeof(rpc_request_t)];
117 arg->size = size;
118 arg->argtype = argtype;
119 arg->magic = RPC_ARG_MAGIC;
120 memcpy(dest: arg->data, src: data, n: size);
121
122 call->size += sizeof(rpc_arg_t) + size;
123 mutex_release(mutex: &call->mutex);
124}
125
126#define RPC_CALL_ARG_IMPL(type, TYPE) \
127 void rpc_call_arg_##type(rpc_call_t *call, type arg) \
128 { \
129 rpc_call_arg(call, RPC_ARGTYPE_##TYPE, &arg, sizeof(arg)); \
130 }
131
132RPC_CALL_ARG_IMPL(u8, UINT8)
133RPC_CALL_ARG_IMPL(u32, UINT32)
134RPC_CALL_ARG_IMPL(u64, UINT64)
135RPC_CALL_ARG_IMPL(s8, INT8)
136RPC_CALL_ARG_IMPL(s32, INT32)
137RPC_CALL_ARG_IMPL(s64, INT64)
138
139void rpc_call_arg_string(rpc_call_t *call, const char *arg)
140{
141 rpc_call_arg(call, argtype: RPC_ARGTYPE_STRING, data: arg, size: strlen(str: arg) + 1); // also send the null terminator
142}
143
144rpc_result_code_t rpc_call_exec(rpc_call_t *call, void **result_data, size_t *data_size)
145{
146 if (result_data && data_size)
147 {
148 *data_size = 0;
149 *result_data = NULL;
150 }
151
152 mutex_acquire(mutex: &call->mutex);
153 mutex_acquire(mutex: &call->server->mutex);
154 call->request->call_id = ++call->server->callid;
155
156 bool written = ipc_write_as_msg(fd: call->server->fd, data: (char *) call->request, size: call->size);
157 if (!written)
158 {
159 mutex_release(mutex: &call->server->mutex);
160 mutex_release(mutex: &call->mutex);
161 return RPC_RESULT_CLIENT_WRITE_FAILED;
162 }
163
164 ipc_msg_t *msg = ipc_read_msg(fd: call->server->fd);
165 if (!msg)
166 {
167 mutex_release(mutex: &call->server->mutex);
168 mutex_release(mutex: &call->mutex);
169 return RPC_RESULT_CLIENT_READ_FAILED;
170 }
171
172 if (msg->size < sizeof(rpc_response_t))
173 {
174 mutex_release(mutex: &call->server->mutex);
175 mutex_release(mutex: &call->mutex);
176 return RPC_RESULT_CLIENT_READ_FAILED;
177 }
178
179 rpc_response_t *response = (rpc_response_t *) msg->data;
180 if (response->magic != RPC_RESPONSE_MAGIC)
181 {
182 mutex_release(mutex: &call->server->mutex);
183 mutex_release(mutex: &call->mutex);
184 return RPC_RESULT_CLIENT_READ_FAILED;
185 }
186
187 if (response->call_id != call->request->call_id)
188 {
189 mutex_release(mutex: &call->server->mutex);
190 mutex_release(mutex: &call->mutex);
191 return RPC_RESULT_CALLID_MISMATCH;
192 }
193
194 if (response->result_code != RPC_RESULT_OK)
195 {
196 mutex_release(mutex: &call->server->mutex);
197 mutex_release(mutex: &call->mutex);
198 return response->result_code;
199 }
200
201 if (msg->size < sizeof(rpc_response_t))
202 {
203 mutex_release(mutex: &call->server->mutex);
204 mutex_release(mutex: &call->mutex);
205 return RPC_RESULT_CLIENT_READ_FAILED;
206 }
207
208 if (result_data && data_size && response->data_size)
209 {
210 *data_size = response->data_size;
211 *result_data = malloc(response->data_size);
212 memcpy(dest: *result_data, src: response->data, n: response->data_size);
213 }
214
215 rpc_result_code_t result = response->result_code;
216 ipc_msg_destroy(buffer: msg);
217 mutex_release(mutex: &call->server->mutex);
218 mutex_release(mutex: &call->mutex);
219 return result;
220}
221
222rpc_result_code_t rpc_simple_call(rpc_server_stub_t *stub, u32 funcid, rpc_result_t *result, const char *argspec, ...)
223{
224 va_list args;
225 va_start(args, argspec);
226 rpc_result_code_t code = rpc_simple_callv(stub, funcid, result, argspec, args);
227 va_end(args);
228 return code;
229}
230
231rpc_result_code_t rpc_simple_callv(rpc_server_stub_t *stub, u32 funcid, rpc_result_t *result, const char *argspec, va_list args)
232{
233 if (unlikely(!argspec))
234 {
235 mos_warn("argspec is NULL");
236 return RPC_RESULT_CLIENT_INVALID_ARGSPEC;
237 }
238
239 rpc_call_t *call = rpc_call_create(server: stub, function_id: funcid);
240
241 if (*argspec == 'v')
242 {
243 if (*++argspec != '\0')
244 {
245 mos_warn("argspec is not empty after 'v' (void) (argspec='%s')", argspec);
246 rpc_call_destroy(call);
247 return RPC_RESULT_CLIENT_INVALID_ARGSPEC;
248 }
249 goto exec;
250 }
251
252 for (const char *c = argspec; *c != '\0'; c++)
253 {
254 switch (*c)
255 {
256 case 'c':
257 {
258 u8 arg = va_arg(args, int);
259 rpc_call_arg(call, argtype: RPC_ARGTYPE_INT8, data: &arg, size: sizeof(arg));
260 break;
261 }
262 case 'i':
263 {
264 u32 arg = va_arg(args, int);
265 rpc_call_arg(call, argtype: RPC_ARGTYPE_INT32, data: &arg, size: sizeof(arg));
266 break;
267 }
268 case 'l':
269 {
270 u64 arg = va_arg(args, long long);
271 rpc_call_arg(call, argtype: RPC_ARGTYPE_INT64, data: &arg, size: sizeof(arg));
272 break;
273 }
274 case 'f':
275 {
276 MOS_LIB_UNREACHABLE(); // TODO: implement
277 // double arg = va_arg(args, double);
278 // rpc_call_arg(call, &arg, sizeof(arg));
279 break;
280 }
281 case 's':
282 {
283 const char *arg = va_arg(args, const char *);
284 rpc_call_arg(call, argtype: RPC_ARGTYPE_STRING, data: arg, size: strlen(str: arg) + 1); // also send the null terminator
285 break;
286 }
287 default: mos_warn("rpc_call: invalid argspec '%c'", *c); return RPC_RESULT_CLIENT_INVALID_ARGSPEC;
288 }
289 }
290
291exec:
292 rpc_call_exec(call, result_data: &result->data, data_size: &result->size);
293 rpc_call_destroy(call);
294
295 return RPC_RESULT_OK;
296}
297
298rpc_result_code_t rpc_do_pb_call(rpc_server_stub_t *stub, u32 funcid, const pb_msgdesc_t *reqm, const void *req, const pb_msgdesc_t *respm, void *resp)
299{
300 size_t bufsize;
301 pb_get_encoded_size(size: &bufsize, fields: reqm, src_struct: req);
302 pb_byte_t buf[bufsize];
303 pb_ostream_t wstream = pb_ostream_from_buffer(buf, bufsize);
304 if (!pb_encode(stream: &wstream, fields: reqm, src_struct: req))
305 return RPC_RESULT_CLIENT_WRITE_FAILED;
306
307 rpc_call_t *call = rpc_call_create(server: stub, function_id: funcid);
308 rpc_call_arg(call, argtype: RPC_ARGTYPE_BUFFER, data: buf, size: wstream.bytes_written);
309
310 void *result = NULL;
311 size_t result_size = 0;
312 rpc_result_code_t result_code = rpc_call_exec(call, result_data: &result, data_size: &result_size);
313 rpc_call_destroy(call);
314
315 if (!respm || !resp)
316 return result_code; // no response expected
317
318 if (result_code != RPC_RESULT_OK)
319 return result_code;
320
321 pb_istream_t stream = pb_istream_from_buffer(buf: (const pb_byte_t *) result, msglen: result_size);
322 if (!pb_decode(stream: &stream, fields: respm, dest_struct: resp))
323 return RPC_RESULT_CLIENT_READ_FAILED;
324
325 free(result);
326 return RPC_RESULT_OK;
327}
328