| 1 | // SPDX-License-Identifier: GPL-3.0-or-later |
| 2 | |
| 3 | #include "librpc/rpc_client.h" |
| 4 | |
| 5 | #include "librpc/internal.h" |
| 6 | #include "librpc/rpc.h" |
| 7 | |
| 8 | #include <atomic> |
| 9 | #include <libipc/ipc.h> |
| 10 | #include <mos/types.h> |
| 11 | #include <pb_decode.h> |
| 12 | #include <pb_encode.h> |
| 13 | #include <stdarg.h> |
| 14 | |
| 15 | #if defined(__MOS_KERNEL__) |
| 16 | #include <mos/lib/sync/mutex.hpp> |
| 17 | #include <mos_stdio.hpp> |
| 18 | #include <mos_stdlib.hpp> |
| 19 | #include <mos_string.hpp> |
| 20 | #else |
| 21 | #include <stdio.h> |
| 22 | #include <stdlib.h> |
| 23 | #include <string.h> |
| 24 | #endif |
| 25 | |
| 26 | #ifdef __MOS_KERNEL__ |
| 27 | #include "mos/assert.hpp" |
| 28 | #include "mos/ipc/ipc_io.hpp" |
| 29 | |
| 30 | #include <mos/platform/platform.hpp> |
| 31 | #include <mos/syscall/decl.h> |
| 32 | #define syscall_ipc_connect(n, s) ipc_connect(n, s).get() |
| 33 | #define syscall_io_close(fd) fd->unref() |
| 34 | #else |
| 35 | #include <mos/syscall/usermode.h> |
| 36 | #endif |
| 37 | |
| 38 | #if !defined(__MOS_KERNEL__) |
| 39 | // fixup for hosted libc |
| 40 | #include <pthread.h> |
| 41 | typedef pthread_mutex_t mutex_t; |
| 42 | #define mutex_acquire(mutex) pthread_mutex_lock(mutex) |
| 43 | #define mutex_release(mutex) pthread_mutex_unlock(mutex) |
| 44 | #define mos_warn(...) fprintf(stderr, __VA_ARGS__) |
| 45 | #define MOS_LIB_UNREACHABLE() __builtin_unreachable() |
| 46 | #endif |
| 47 | |
| 48 | #define RPC_CLIENT_SMH_SIZE MOS_PAGE_SIZE |
| 49 | |
// Client-side handle for one connected RPC server; created by rpc_client_create() and
// released by rpc_client_destroy().
// NOTE(review): instances are allocated with calloc(), so `mutex` and `callid` rely on an
// all-zero bit pattern being a valid initial state — confirm for every supported mutex_t.
typedef struct rpc_server_stub
{
    const char *server_name;   // caller-supplied pointer, stored as-is (not copied)
    ipcfd_t fd;                // IPC channel obtained from syscall_ipc_connect()
    mutex_t mutex; // only one call at a time
    std::atomic_size_t callid; // pre-incremented per rpc_call_exec(); compared against response->call_id
} rpc_server_stub_t;
| 57 | |
// A single request under construction: created by rpc_call_create(), extended with
// rpc_call_arg(), sent by rpc_call_exec(), and freed by rpc_call_destroy().
typedef struct rpc_call
{
    rpc_server_stub_t *server; // stub the call is sent through; borrowed, not freed by rpc_call_destroy()
    rpc_request_t *request;    // heap buffer: request header followed by packed rpc_arg_t records (grown via realloc)
    size_t size;               // total bytes in *request including the header; this many bytes are sent
    mutex_t mutex;             // serializes rpc_call_arg / rpc_call_exec / rpc_call_destroy on this call
} rpc_call_t;
| 65 | |
| 66 | rpc_server_stub_t *rpc_client_create(const char *server_name) |
| 67 | { |
| 68 | rpc_server_stub_t *client = (rpc_server_stub_t *) calloc(nmemb: 1, size: sizeof(rpc_server_stub_t)); |
| 69 | client->server_name = server_name; |
| 70 | client->fd = syscall_ipc_connect(server_name, RPC_CLIENT_SMH_SIZE); |
| 71 | |
| 72 | if (IS_ERR_VALUE(client->fd)) |
| 73 | { |
| 74 | free(ptr: client); |
| 75 | return NULL; |
| 76 | } |
| 77 | |
| 78 | return client; |
| 79 | } |
| 80 | |
| 81 | void rpc_client_destroy(rpc_server_stub_t *server) |
| 82 | { |
| 83 | mutex_acquire(mutex: &server->mutex); |
| 84 | syscall_io_close(server->fd); |
| 85 | free(ptr: server); |
| 86 | } |
| 87 | |
| 88 | rpc_call_t *rpc_call_create(rpc_server_stub_t *server, u32 function_id) |
| 89 | { |
| 90 | rpc_call_t *call = (rpc_call_t *) calloc(nmemb: 1, size: sizeof(rpc_call_t)); |
| 91 | call->request = (rpc_request_t *) calloc(nmemb: 1, size: sizeof(rpc_request_t)); |
| 92 | call->request->magic = RPC_REQUEST_MAGIC; |
| 93 | call->request->function_id = function_id; |
| 94 | call->size = sizeof(rpc_request_t); |
| 95 | call->server = server; |
| 96 | |
| 97 | return call; |
| 98 | } |
| 99 | |
| 100 | void rpc_call_destroy(rpc_call_t *call) |
| 101 | { |
| 102 | mutex_acquire(mutex: &call->mutex); |
| 103 | free(ptr: call->request); |
| 104 | free(ptr: call); |
| 105 | } |
| 106 | |
| 107 | void rpc_call_arg(rpc_call_t *call, rpc_argtype_t argtype, const void *data, size_t size) |
| 108 | { |
| 109 | mutex_acquire(mutex: &call->mutex); |
| 110 | call->request = (rpc_request_t *) realloc(ptr: call->request, size: call->size + sizeof(rpc_arg_t) + size); |
| 111 | call->request->args_count += 1; |
| 112 | |
| 113 | rpc_arg_t *arg = (rpc_arg_t *) &call->request->args_array[call->size - sizeof(rpc_request_t)]; |
| 114 | arg->size = size; |
| 115 | arg->argtype = argtype; |
| 116 | arg->magic = RPC_ARG_MAGIC; |
| 117 | memcpy(dest: arg->data, src: data, n: size); |
| 118 | |
| 119 | call->size += sizeof(rpc_arg_t) + size; |
| 120 | mutex_release(mutex: &call->mutex); |
| 121 | } |
| 122 | |
| 123 | #define RPC_CALL_ARG_IMPL(type, TYPE) \ |
| 124 | void rpc_call_arg_##type(rpc_call_t *call, type arg) \ |
| 125 | { \ |
| 126 | rpc_call_arg(call, RPC_ARGTYPE_##TYPE, &arg, sizeof(arg)); \ |
| 127 | } |
| 128 | |
| 129 | RPC_CALL_ARG_IMPL(u8, UINT8) |
| 130 | RPC_CALL_ARG_IMPL(u32, UINT32) |
| 131 | RPC_CALL_ARG_IMPL(u64, UINT64) |
| 132 | RPC_CALL_ARG_IMPL(s8, INT8) |
| 133 | RPC_CALL_ARG_IMPL(s32, INT32) |
| 134 | RPC_CALL_ARG_IMPL(s64, INT64) |
| 135 | |
| 136 | void rpc_call_arg_string(rpc_call_t *call, const char *arg) |
| 137 | { |
| 138 | rpc_call_arg(call, argtype: RPC_ARGTYPE_STRING, data: arg, size: strlen(str: arg) + 1); // also send the null terminator |
| 139 | } |
| 140 | |
| 141 | rpc_result_code_t rpc_call_exec(rpc_call_t *call, void **result_data, size_t *data_size) |
| 142 | { |
| 143 | if (result_data && data_size) |
| 144 | { |
| 145 | *data_size = 0; |
| 146 | *result_data = NULL; |
| 147 | } |
| 148 | |
| 149 | mutex_acquire(mutex: &call->mutex); |
| 150 | mutex_acquire(mutex: &call->server->mutex); |
| 151 | call->request->call_id = ++call->server->callid; |
| 152 | |
| 153 | bool written = ipc_write_as_msg(fd: call->server->fd, data: (char *) call->request, size: call->size); |
| 154 | if (!written) |
| 155 | { |
| 156 | mutex_release(mutex: &call->server->mutex); |
| 157 | mutex_release(mutex: &call->mutex); |
| 158 | return RPC_RESULT_CLIENT_WRITE_FAILED; |
| 159 | } |
| 160 | |
| 161 | ipc_msg_t *msg = ipc_read_msg(fd: call->server->fd); |
| 162 | if (!msg) |
| 163 | { |
| 164 | mutex_release(mutex: &call->server->mutex); |
| 165 | mutex_release(mutex: &call->mutex); |
| 166 | return RPC_RESULT_CLIENT_READ_FAILED; |
| 167 | } |
| 168 | |
| 169 | if (msg->size < sizeof(rpc_response_t)) |
| 170 | { |
| 171 | mutex_release(mutex: &call->server->mutex); |
| 172 | mutex_release(mutex: &call->mutex); |
| 173 | return RPC_RESULT_CLIENT_READ_FAILED; |
| 174 | } |
| 175 | |
| 176 | rpc_response_t *response = (rpc_response_t *) msg->data; |
| 177 | if (response->magic != RPC_RESPONSE_MAGIC) |
| 178 | { |
| 179 | mutex_release(mutex: &call->server->mutex); |
| 180 | mutex_release(mutex: &call->mutex); |
| 181 | return RPC_RESULT_CLIENT_READ_FAILED; |
| 182 | } |
| 183 | |
| 184 | if (response->call_id != call->request->call_id) |
| 185 | { |
| 186 | mutex_release(mutex: &call->server->mutex); |
| 187 | mutex_release(mutex: &call->mutex); |
| 188 | return RPC_RESULT_CALLID_MISMATCH; |
| 189 | } |
| 190 | |
| 191 | if (response->result_code != RPC_RESULT_OK) |
| 192 | { |
| 193 | mutex_release(mutex: &call->server->mutex); |
| 194 | mutex_release(mutex: &call->mutex); |
| 195 | return response->result_code; |
| 196 | } |
| 197 | |
| 198 | if (msg->size < sizeof(rpc_response_t)) |
| 199 | { |
| 200 | mutex_release(mutex: &call->server->mutex); |
| 201 | mutex_release(mutex: &call->mutex); |
| 202 | return RPC_RESULT_CLIENT_READ_FAILED; |
| 203 | } |
| 204 | |
| 205 | if (result_data && data_size && response->data_size) |
| 206 | { |
| 207 | *data_size = response->data_size; |
| 208 | *result_data = malloc(size: response->data_size); |
| 209 | memcpy(dest: *result_data, src: response->data, n: response->data_size); |
| 210 | } |
| 211 | |
| 212 | rpc_result_code_t result = response->result_code; |
| 213 | ipc_msg_destroy(buffer: msg); |
| 214 | mutex_release(mutex: &call->server->mutex); |
| 215 | mutex_release(mutex: &call->mutex); |
| 216 | return result; |
| 217 | } |
| 218 | |
| 219 | rpc_result_code_t rpc_simple_call(rpc_server_stub_t *stub, u32 funcid, rpc_result_t *result, const char *argspec, ...) |
| 220 | { |
| 221 | va_list args; |
| 222 | va_start(args, argspec); |
| 223 | rpc_result_code_t code = rpc_simple_callv(stub, funcid, result, argspec, args); |
| 224 | va_end(args); |
| 225 | return code; |
| 226 | } |
| 227 | |
| 228 | rpc_result_code_t rpc_simple_callv(rpc_server_stub_t *stub, u32 funcid, rpc_result_t *result, const char *argspec, va_list args) |
| 229 | { |
| 230 | if (unlikely(!argspec)) |
| 231 | { |
| 232 | mos_warn("argspec is NULL" ); |
| 233 | return RPC_RESULT_CLIENT_INVALID_ARGSPEC; |
| 234 | } |
| 235 | |
| 236 | rpc_call_t *call = rpc_call_create(server: stub, function_id: funcid); |
| 237 | |
| 238 | if (*argspec == 'v') |
| 239 | { |
| 240 | if (*++argspec != '\0') |
| 241 | { |
| 242 | mos_warn("argspec is not empty after 'v' (void) (argspec='%s')" , argspec); |
| 243 | rpc_call_destroy(call); |
| 244 | return RPC_RESULT_CLIENT_INVALID_ARGSPEC; |
| 245 | } |
| 246 | goto exec; |
| 247 | } |
| 248 | |
| 249 | for (const char *c = argspec; *c != '\0'; c++) |
| 250 | { |
| 251 | switch (*c) |
| 252 | { |
| 253 | case 'c': |
| 254 | { |
| 255 | u8 arg = va_arg(args, int); |
| 256 | rpc_call_arg(call, argtype: RPC_ARGTYPE_UINT8, data: &arg, size: sizeof(arg)); |
| 257 | break; |
| 258 | } |
| 259 | case 'i': |
| 260 | { |
| 261 | u32 arg = va_arg(args, int); |
| 262 | rpc_call_arg(call, argtype: RPC_ARGTYPE_INT32, data: &arg, size: sizeof(arg)); |
| 263 | break; |
| 264 | } |
| 265 | case 'l': |
| 266 | { |
| 267 | u64 arg = va_arg(args, long long); |
| 268 | rpc_call_arg(call, argtype: RPC_ARGTYPE_INT64, data: &arg, size: sizeof(arg)); |
| 269 | break; |
| 270 | } |
| 271 | case 'f': |
| 272 | { |
| 273 | MOS_LIB_UNREACHABLE(); // TODO: implement |
| 274 | // double arg = va_arg(args, double); |
| 275 | // rpc_call_arg(call, &arg, sizeof(arg)); |
| 276 | break; |
| 277 | } |
| 278 | case 's': |
| 279 | { |
| 280 | const char *arg = va_arg(args, const char *); |
| 281 | rpc_call_arg(call, argtype: RPC_ARGTYPE_STRING, data: arg, size: strlen(str: arg) + 1); // also send the null terminator |
| 282 | break; |
| 283 | } |
| 284 | default: mos_warn("rpc_call: invalid argspec '%c'" , *c); return RPC_RESULT_CLIENT_INVALID_ARGSPEC; |
| 285 | } |
| 286 | } |
| 287 | |
| 288 | exec: |
| 289 | rpc_call_exec(call, result_data: &result->data, data_size: &result->size); |
| 290 | rpc_call_destroy(call); |
| 291 | |
| 292 | return RPC_RESULT_OK; |
| 293 | } |
| 294 | |
| 295 | rpc_result_code_t rpc_do_pb_call(rpc_server_stub_t *stub, u32 funcid, const pb_msgdesc_t *reqm, const void *req, const pb_msgdesc_t *respm, void *resp) |
| 296 | { |
| 297 | size_t bufsize; |
| 298 | pb_get_encoded_size(size: &bufsize, fields: reqm, src_struct: req); |
| 299 | pb_byte_t buf[bufsize]; |
| 300 | pb_ostream_t wstream = pb_ostream_from_buffer(buf, bufsize); |
| 301 | if (!pb_encode(stream: &wstream, fields: reqm, src_struct: req)) |
| 302 | return RPC_RESULT_CLIENT_WRITE_FAILED; |
| 303 | |
| 304 | rpc_call_t *call = rpc_call_create(server: stub, function_id: funcid); |
| 305 | rpc_call_arg(call, argtype: RPC_ARGTYPE_BUFFER, data: buf, size: wstream.bytes_written); |
| 306 | |
| 307 | void *result = NULL; |
| 308 | size_t result_size = 0; |
| 309 | rpc_result_code_t result_code = rpc_call_exec(call, result_data: &result, data_size: &result_size); |
| 310 | rpc_call_destroy(call); |
| 311 | |
| 312 | if (!respm || !resp) |
| 313 | return result_code; // no response expected |
| 314 | |
| 315 | if (result_code != RPC_RESULT_OK) |
| 316 | return result_code; |
| 317 | |
| 318 | pb_istream_t stream = pb_istream_from_buffer(buf: (const pb_byte_t *) result, msglen: result_size); |
| 319 | if (!pb_decode(stream: &stream, fields: respm, dest_struct: resp)) |
| 320 | return RPC_RESULT_CLIENT_READ_FAILED; |
| 321 | |
| 322 | free(ptr: result); |
| 323 | return RPC_RESULT_OK; |
| 324 | } |
| 325 | |