1// SPDX-License-Identifier: GPL-3.0-or-later
2
3#include "librpc/rpc_client.h"
4
5#include "librpc/internal.h"
6#include "librpc/rpc.h"
7
8#include <atomic>
9#include <libipc/ipc.h>
10#include <mos/types.h>
11#include <pb_decode.h>
12#include <pb_encode.h>
13#include <stdarg.h>
14
15#if defined(__MOS_KERNEL__)
16#include <mos/lib/sync/mutex.hpp>
17#include <mos_stdio.hpp>
18#include <mos_stdlib.hpp>
19#include <mos_string.hpp>
20#else
21#include <stdio.h>
22#include <stdlib.h>
23#include <string.h>
24#endif
25
26#ifdef __MOS_KERNEL__
27#include "mos/assert.hpp"
28#include "mos/ipc/ipc_io.hpp"
29
30#include <mos/platform/platform.hpp>
31#include <mos/syscall/decl.h>
32#define syscall_ipc_connect(n, s) ipc_connect(n, s).get()
33#define syscall_io_close(fd) io_unref(fd)
34#else
35#include <mos/syscall/usermode.h>
36#endif
37
38#if !defined(__MOS_KERNEL__)
39// fixup for hosted libc
40#include <pthread.h>
41typedef pthread_mutex_t mutex_t;
42#define mutex_acquire(mutex) pthread_mutex_lock(mutex)
43#define mutex_release(mutex) pthread_mutex_unlock(mutex)
44#define mos_warn(...) fprintf(stderr, __VA_ARGS__)
45#define MOS_LIB_UNREACHABLE() __builtin_unreachable()
46#endif
47
48#define RPC_CLIENT_SMH_SIZE MOS_PAGE_SIZE
49
50typedef struct rpc_server_stub
51{
52 const char *server_name;
53 ipcfd_t fd;
54 mutex_t mutex; // only one call at a time
55 std::atomic_size_t callid;
56} rpc_server_stub_t;
57
58typedef struct rpc_call
59{
60 rpc_server_stub_t *server;
61 rpc_request_t *request;
62 size_t size;
63 mutex_t mutex;
64} rpc_call_t;
65
66rpc_server_stub_t *rpc_client_create(const char *server_name)
67{
68 rpc_server_stub_t *client = (rpc_server_stub_t *) calloc(nmemb: 1, size: sizeof(rpc_server_stub_t));
69 client->server_name = server_name;
70 client->fd = syscall_ipc_connect(server_name, RPC_CLIENT_SMH_SIZE);
71
72 if (IS_ERR_VALUE(client->fd))
73 {
74 free(ptr: client);
75 return NULL;
76 }
77
78 return client;
79}
80
81void rpc_client_destroy(rpc_server_stub_t *server)
82{
83 mutex_acquire(mutex: &server->mutex);
84 syscall_io_close(server->fd);
85 free(ptr: server);
86}
87
88rpc_call_t *rpc_call_create(rpc_server_stub_t *server, u32 function_id)
89{
90 rpc_call_t *call = (rpc_call_t *) calloc(nmemb: 1, size: sizeof(rpc_call_t));
91 call->request = (rpc_request_t *) calloc(nmemb: 1, size: sizeof(rpc_request_t));
92 call->request->magic = RPC_REQUEST_MAGIC;
93 call->request->function_id = function_id;
94 call->size = sizeof(rpc_request_t);
95 call->server = server;
96
97 return call;
98}
99
100void rpc_call_destroy(rpc_call_t *call)
101{
102 mutex_acquire(mutex: &call->mutex);
103 free(ptr: call->request);
104 free(ptr: call);
105}
106
107void rpc_call_arg(rpc_call_t *call, rpc_argtype_t argtype, const void *data, size_t size)
108{
109 mutex_acquire(mutex: &call->mutex);
110 call->request = (rpc_request_t *) realloc(ptr: call->request, size: call->size + sizeof(rpc_arg_t) + size);
111 call->request->args_count += 1;
112
113 rpc_arg_t *arg = (rpc_arg_t *) &call->request->args_array[call->size - sizeof(rpc_request_t)];
114 arg->size = size;
115 arg->argtype = argtype;
116 arg->magic = RPC_ARG_MAGIC;
117 memcpy(dest: arg->data, src: data, n: size);
118
119 call->size += sizeof(rpc_arg_t) + size;
120 mutex_release(mutex: &call->mutex);
121}
122
123#define RPC_CALL_ARG_IMPL(type, TYPE) \
124 void rpc_call_arg_##type(rpc_call_t *call, type arg) \
125 { \
126 rpc_call_arg(call, RPC_ARGTYPE_##TYPE, &arg, sizeof(arg)); \
127 }
128
129RPC_CALL_ARG_IMPL(u8, UINT8)
130RPC_CALL_ARG_IMPL(u32, UINT32)
131RPC_CALL_ARG_IMPL(u64, UINT64)
132RPC_CALL_ARG_IMPL(s8, INT8)
133RPC_CALL_ARG_IMPL(s32, INT32)
134RPC_CALL_ARG_IMPL(s64, INT64)
135
136void rpc_call_arg_string(rpc_call_t *call, const char *arg)
137{
138 rpc_call_arg(call, argtype: RPC_ARGTYPE_STRING, data: arg, size: strlen(str: arg) + 1); // also send the null terminator
139}
140
141rpc_result_code_t rpc_call_exec(rpc_call_t *call, void **result_data, size_t *data_size)
142{
143 if (result_data && data_size)
144 {
145 *data_size = 0;
146 *result_data = NULL;
147 }
148
149 mutex_acquire(mutex: &call->mutex);
150 mutex_acquire(mutex: &call->server->mutex);
151 call->request->call_id = ++call->server->callid;
152
153 bool written = ipc_write_as_msg(fd: call->server->fd, data: (char *) call->request, size: call->size);
154 if (!written)
155 {
156 mutex_release(mutex: &call->server->mutex);
157 mutex_release(mutex: &call->mutex);
158 return RPC_RESULT_CLIENT_WRITE_FAILED;
159 }
160
161 ipc_msg_t *msg = ipc_read_msg(fd: call->server->fd);
162 if (!msg)
163 {
164 mutex_release(mutex: &call->server->mutex);
165 mutex_release(mutex: &call->mutex);
166 return RPC_RESULT_CLIENT_READ_FAILED;
167 }
168
169 if (msg->size < sizeof(rpc_response_t))
170 {
171 mutex_release(mutex: &call->server->mutex);
172 mutex_release(mutex: &call->mutex);
173 return RPC_RESULT_CLIENT_READ_FAILED;
174 }
175
176 rpc_response_t *response = (rpc_response_t *) msg->data;
177 if (response->magic != RPC_RESPONSE_MAGIC)
178 {
179 mutex_release(mutex: &call->server->mutex);
180 mutex_release(mutex: &call->mutex);
181 return RPC_RESULT_CLIENT_READ_FAILED;
182 }
183
184 if (response->call_id != call->request->call_id)
185 {
186 mutex_release(mutex: &call->server->mutex);
187 mutex_release(mutex: &call->mutex);
188 return RPC_RESULT_CALLID_MISMATCH;
189 }
190
191 if (response->result_code != RPC_RESULT_OK)
192 {
193 mutex_release(mutex: &call->server->mutex);
194 mutex_release(mutex: &call->mutex);
195 return response->result_code;
196 }
197
198 if (msg->size < sizeof(rpc_response_t))
199 {
200 mutex_release(mutex: &call->server->mutex);
201 mutex_release(mutex: &call->mutex);
202 return RPC_RESULT_CLIENT_READ_FAILED;
203 }
204
205 if (result_data && data_size && response->data_size)
206 {
207 *data_size = response->data_size;
208 *result_data = malloc(size: response->data_size);
209 memcpy(dest: *result_data, src: response->data, n: response->data_size);
210 }
211
212 rpc_result_code_t result = response->result_code;
213 ipc_msg_destroy(buffer: msg);
214 mutex_release(mutex: &call->server->mutex);
215 mutex_release(mutex: &call->mutex);
216 return result;
217}
218
219rpc_result_code_t rpc_simple_call(rpc_server_stub_t *stub, u32 funcid, rpc_result_t *result, const char *argspec, ...)
220{
221 va_list args;
222 va_start(args, argspec);
223 rpc_result_code_t code = rpc_simple_callv(stub, funcid, result, argspec, args);
224 va_end(args);
225 return code;
226}
227
228rpc_result_code_t rpc_simple_callv(rpc_server_stub_t *stub, u32 funcid, rpc_result_t *result, const char *argspec, va_list args)
229{
230 if (unlikely(!argspec))
231 {
232 mos_warn("argspec is NULL");
233 return RPC_RESULT_CLIENT_INVALID_ARGSPEC;
234 }
235
236 rpc_call_t *call = rpc_call_create(server: stub, function_id: funcid);
237
238 if (*argspec == 'v')
239 {
240 if (*++argspec != '\0')
241 {
242 mos_warn("argspec is not empty after 'v' (void) (argspec='%s')", argspec);
243 rpc_call_destroy(call);
244 return RPC_RESULT_CLIENT_INVALID_ARGSPEC;
245 }
246 goto exec;
247 }
248
249 for (const char *c = argspec; *c != '\0'; c++)
250 {
251 switch (*c)
252 {
253 case 'c':
254 {
255 u8 arg = va_arg(args, int);
256 rpc_call_arg(call, argtype: RPC_ARGTYPE_INT8, data: &arg, size: sizeof(arg));
257 break;
258 }
259 case 'i':
260 {
261 u32 arg = va_arg(args, int);
262 rpc_call_arg(call, argtype: RPC_ARGTYPE_INT32, data: &arg, size: sizeof(arg));
263 break;
264 }
265 case 'l':
266 {
267 u64 arg = va_arg(args, long long);
268 rpc_call_arg(call, argtype: RPC_ARGTYPE_INT64, data: &arg, size: sizeof(arg));
269 break;
270 }
271 case 'f':
272 {
273 MOS_LIB_UNREACHABLE(); // TODO: implement
274 // double arg = va_arg(args, double);
275 // rpc_call_arg(call, &arg, sizeof(arg));
276 break;
277 }
278 case 's':
279 {
280 const char *arg = va_arg(args, const char *);
281 rpc_call_arg(call, argtype: RPC_ARGTYPE_STRING, data: arg, size: strlen(str: arg) + 1); // also send the null terminator
282 break;
283 }
284 default: mos_warn("rpc_call: invalid argspec '%c'", *c); return RPC_RESULT_CLIENT_INVALID_ARGSPEC;
285 }
286 }
287
288exec:
289 rpc_call_exec(call, result_data: &result->data, data_size: &result->size);
290 rpc_call_destroy(call);
291
292 return RPC_RESULT_OK;
293}
294
295rpc_result_code_t rpc_do_pb_call(rpc_server_stub_t *stub, u32 funcid, const pb_msgdesc_t *reqm, const void *req, const pb_msgdesc_t *respm, void *resp)
296{
297 size_t bufsize;
298 pb_get_encoded_size(size: &bufsize, fields: reqm, src_struct: req);
299 pb_byte_t buf[bufsize];
300 pb_ostream_t wstream = pb_ostream_from_buffer(buf, bufsize);
301 if (!pb_encode(stream: &wstream, fields: reqm, src_struct: req))
302 return RPC_RESULT_CLIENT_WRITE_FAILED;
303
304 rpc_call_t *call = rpc_call_create(server: stub, function_id: funcid);
305 rpc_call_arg(call, argtype: RPC_ARGTYPE_BUFFER, data: buf, size: wstream.bytes_written);
306
307 void *result = NULL;
308 size_t result_size = 0;
309 rpc_result_code_t result_code = rpc_call_exec(call, result_data: &result, data_size: &result_size);
310 rpc_call_destroy(call);
311
312 if (!respm || !resp)
313 return result_code; // no response expected
314
315 if (result_code != RPC_RESULT_OK)
316 return result_code;
317
318 pb_istream_t stream = pb_istream_from_buffer(buf: (const pb_byte_t *) result, msglen: result_size);
319 if (!pb_decode(stream: &stream, fields: respm, dest_struct: resp))
320 return RPC_RESULT_CLIENT_READ_FAILED;
321
322 free(ptr: result);
323 return RPC_RESULT_OK;
324}
325