WAMR `wasm_runtime_load`

 July 15, 2021 at 10:03 am

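The module loader in the enclave (`handle_cmd_load_module`):

  1. Allocate a single buffer holding the EnclaveModule struct followed by space for the WASM file
  2. Copy the WASM file buffer from the application side into that space
  3. Call wasm_runtime_load on the enclave-side copy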
/*
 * Enclave-side handler for CMD_LOAD_MODULE.
 *
 * `args` holds `argc` uint64 slots, filled by the application's
 * load_module():
 *   args[0]  pointer to the WASM file buffer
 *   args[1]  size of the WASM file in bytes
 *   args[2]  pointer to the error message buffer
 *   args[3]  size of the error message buffer
 *
 * One allocation is made with the layout [EnclaveModule][copy of WASM file].
 * The file bytes are copied in, and the copy is handed to wasm_runtime_load().
 * On return, args[0] is overwritten with the EnclaveModule pointer, or with
 * NULL on failure. On failure, error_buf carries the message.
 *
 * NOTE(review): wasm_file and error_buf are caller-supplied pointers that are
 * read/written without an sgx_is_outside_enclave() range check here.
 * Confirm that the ecall boundary validates them.
 */
static void
handle_cmd_load_module(uint64 *args, uint32 argc)
{
    uint64 *args_org = args;
    char *wasm_file;
    uint32 wasm_file_size;
    char *error_buf;
    uint32 error_buf_size;
    uint64 total_size;
    EnclaveModule *enclave_module;

    bh_assert(argc == 4);
    /* Validate the slot count before touching any slot. bh_assert is
       presumably a debug-only check, and reading args[1..3] with fewer
       slots would overrun the argument buffer. */
    if (argc < 4) {
        if (argc > 0)
            *(void **)args_org = NULL;
        return;
    }

    wasm_file = *(char **)args++;
    wasm_file_size = *(uint32 *)args++;
    error_buf = *(char **)args++;
    error_buf_size = *(uint32 *)args++;
    /* Computed in 64 bits so the UINT32_MAX check below cannot wrap. */
    total_size = sizeof(EnclaveModule) + (uint64)wasm_file_size;

    if (total_size >= UINT32_MAX
        || !(enclave_module = (EnclaveModule *)
                    wasm_runtime_malloc((uint32)total_size))) {
        set_error_buf(error_buf, error_buf_size,
                      "WASM module load failed: "
                      "allocate memory failed.");
        *(void **)args_org = NULL;
        return;
    }

    memset(enclave_module, 0, (uint32)total_size);
    /* The WASM copy lives directly after the struct in the same block. */
    enclave_module->wasm_file = (uint8 *)enclave_module
                                + sizeof(EnclaveModule);
    bh_memcpy_s(enclave_module->wasm_file, wasm_file_size,
                wasm_file, wasm_file_size);

    if (!(enclave_module->module =
                wasm_runtime_load(enclave_module->wasm_file, wasm_file_size,
                                  error_buf, error_buf_size))) {
        wasm_runtime_free(enclave_module);
        *(void **)args_org = NULL;
        return;
    }

    /* Hand the module back to the application through slot 0. */
    *(EnclaveModule **)args_org = enclave_module;

    LOG_VERBOSE("Load module success.\n");
}

It is invoked indirectly, through the `CMD_LOAD_MODULE` ecall, by `load_module` on the application side.

/*
 * Application-side wrapper that asks the enclave to load a WASM module.
 *
 * The four arguments are packed into uint64 slots and sent as
 * CMD_LOAD_MODULE through ecall_handle_command(). The enclave stores the
 * module handle (or NULL) back into slot 0, and that value is returned.
 * NULL is also returned if the ecall itself does not report SGX_SUCCESS.
 */
static void *
load_module(uint8_t *wasm_file_buf, uint32_t wasm_file_size,
            char *error_buf, uint32_t error_buf_size)
{
    uint64_t ecall_args[4] = {
        (uint64_t)(uintptr_t)wasm_file_buf,
        wasm_file_size,
        (uint64_t)(uintptr_t)error_buf,
        error_buf_size,
    };

    if (ecall_handle_command(g_eid, CMD_LOAD_MODULE, (uint8_t *)ecall_args,
                             sizeof(ecall_args)) != SGX_SUCCESS) {
        printf("Call ecall_handle_command() failed.\n");
        return NULL;
    }

    /* Slot 0 now carries the enclave's result instead of the input buffer. */
    return (void *)(uintptr_t)ecall_args[0];
}

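The enclave handler then calls wasm_runtime_load to load the module from the copied buffer.

wasm_runtime_load

  • Detect the package type of the buffer (WASM bytecode or AOT file)
  • Load the module accordingly: bytecode is loaded with wasm_load and, when both AOT and JIT are enabled, converted with aot_convert_wasm_module (otherwise it is used directly by the interpreter); an AOT file is loaded with aot_load_from_aot_file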
/*
 * Load a module from an in-memory buffer.
 *
 * The buffer's package type selects the loading path:
 *   - WASM bytecode: parsed with wasm_load(). When both AOT and JIT are
 *     enabled, the result is converted with aot_convert_wasm_module().
 *     In an INTERP build, the bytecode module is used as-is.
 *   - AOT file: loaded with aot_load_from_aot_file() (AOT builds only).
 * The resulting module is passed through register_module_with_null_name().
 *
 * If no enabled path accepts the buffer, NULL is returned and error_buf is
 * set. The message is "unexpected end" when the buffer is shorter than
 * 4 bytes, and "magic header not detected" otherwise.
 */
WASMModuleCommon *
wasm_runtime_load(const uint8 *buf, uint32 size,
                  char *error_buf, uint32 error_buf_size)
{
    WASMModuleCommon *loaded = NULL;

    if (get_package_type(buf, size) == Wasm_Module_Bytecode) {
#if WASM_ENABLE_AOT != 0 && WASM_ENABLE_JIT != 0
        WASMModule *bytecode_module;
        AOTModule *converted;

        bytecode_module = wasm_load(buf, size, error_buf, error_buf_size);
        if (bytecode_module == NULL)
            return NULL;

        converted = aot_convert_wasm_module(bytecode_module,
                                            error_buf, error_buf_size);
        if (converted == NULL) {
            /* Conversion failed: release the intermediate bytecode module. */
            wasm_unload(bytecode_module);
            return NULL;
        }

        loaded = (WASMModuleCommon *)converted;
        return register_module_with_null_name(loaded,
                                              error_buf, error_buf_size);
#elif WASM_ENABLE_INTERP != 0
        loaded = (WASMModuleCommon *)wasm_load(buf, size,
                                               error_buf, error_buf_size);
        return register_module_with_null_name(loaded,
                                              error_buf, error_buf_size);
#endif
    }
    else if (get_package_type(buf, size) == Wasm_Module_AoT) {
#if WASM_ENABLE_AOT != 0
        loaded = (WASMModuleCommon *)aot_load_from_aot_file(buf, size,
                                                            error_buf,
                                                            error_buf_size);
        return register_module_with_null_name(loaded,
                                              error_buf, error_buf_size);
#endif
    }

    /* Reached when the format is unknown, or when it is known but its
       loader is not compiled in. */
    set_error_buf(error_buf, error_buf_size,
                  size < 4
                      ? "WASM module load failed: unexpected end"
                      : "WASM module load failed: magic header not detected");
    return NULL;
}

wasm_loader_load

/*
 * Parse a WASM bytecode buffer into a newly allocated WASMModule.
 *
 * An empty module is created with create_module() and populated by load().
 * Returns the module on success. Returns NULL with error_buf set on
 * failure; if load() fails, the partially built module is released with
 * wasm_loader_unload().
 */
WASMModule*
wasm_loader_load(const uint8 *buf, uint32 size, char *error_buf, uint32 error_buf_size)
{
    WASMModule *module;

    if (!(module = create_module(error_buf, error_buf_size)))
        return NULL;

    if (load(buf, size, module, error_buf, error_buf_size)) {
        LOG_VERBOSE("Load module success.\n");
        return module;
    }

    /* Parsing failed: drop whatever load() managed to populate. */
    wasm_loader_unload(module);
    return NULL;
}

create_module

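Allocate a WASMModule struct and set only the fields that need explicit values: the module type, the "no start function" sentinel, and (with multi-module support) the import module list head.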

/*
 * Allocate an empty bytecode WASMModule.
 *
 * Sets module_type to Wasm_Module_Bytecode and start_function to the
 * (uint32)-1 "no start function" sentinel. With multi-module support, it
 * also points import_module_list at the embedded list head. Other fields are
 * left as loader_malloc() returned them (presumably zero-filled; verify).
 * Returns NULL, with error_buf set by loader_malloc(), if allocation fails.
 */
static WASMModule*
create_module(char *error_buf, uint32 error_buf_size)
{
    WASMModule *new_module =
        loader_malloc(sizeof(WASMModule), error_buf, error_buf_size);

    if (new_module != NULL) {
        new_module->module_type = Wasm_Module_Bytecode;
        /* (uint32)-1 means the module has no start function. */
        new_module->start_function = (uint32)-1;
#if WASM_ENABLE_MULTI_MODULE != 0
        new_module->import_module_list = &new_module->import_module_list_head;
#endif
    }

    return new_module;
}

load

  • Checks that the buffer is not exhausted before every read
  1. Check the magic number
  2. Check the version
  3. Split the binary into a linked list of sections (create_sections)
  4. Populate the module from the collected sections (load_from_sections)
  5. Free the section linked list (destroy_sections)
/*
 * Parse a complete WASM binary into `module`.
 *
 * First the 4-byte magic number and the 4-byte version are checked; both
 * are byte-swapped with exchange32() on big-endian hosts before comparison.
 * Next, the binary is split into sections with create_sections(), and
 * `module` is filled from them with load_from_sections(). The temporary
 * section list is released on both success and failure.
 * Returns false with error_buf set on failure.
 */
static bool
load(const uint8 *buf, uint32 size, WASMModule *module,
     char *error_buf, uint32 error_buf_size)
{
    const uint8 *cursor = buf, *end = buf + size;
    uint32 magic_number, version;
    WASMSection *section_list = NULL;
    bool ok;

    CHECK_BUF1(cursor, end, sizeof(uint32));
    magic_number = read_uint32(cursor);
    if (!is_little_endian())
        exchange32((uint8*)&magic_number);
    if (magic_number != WASM_MAGIC_NUMBER) {
        set_error_buf(error_buf, error_buf_size,
                      "magic header not detected");
        return false;
    }

    CHECK_BUF1(cursor, end, sizeof(uint32));
    version = read_uint32(cursor);
    if (!is_little_endian())
        exchange32((uint8*)&version);
    if (version != WASM_CURRENT_VERSION) {
        set_error_buf(error_buf, error_buf_size,
                      "unknown binary version");
        return false;
    }

    /* load_from_sections only runs if create_sections succeeded; the list
       (possibly partial) is released either way. */
    ok = create_sections(buf, size, &section_list, error_buf, error_buf_size)
         && load_from_sections(module, section_list, error_buf, error_buf_size);
    destroy_sections(section_list);
    return ok;

/* Presumably the jump target of CHECK_BUF1 on a short buffer — keep it. */
fail:
    return false;
}

create_sections

Initialize sections: read each section's type and length, producing a linked list of WASMSection nodes.

/* WASM section descriptor: one node of a singly linked list of sections.
   It records where a section's body lies and how long it is; it does not
   own the body bytes. */
typedef struct wasm_section_t {
    /* next section in the list */
    struct wasm_section_t *next;
    /* section type (the section id byte read from the binary) */
    int section_type;
    /* section body, not including the type byte and size field */
    uint8_t *section_body;
    /* section body size in bytes */
    uint32_t section_body_size;
} wasm_section_t, aot_section_t, *wasm_section_list_t, *aot_section_list_t;
/*
 * Split a WASM binary into a singly linked list of section descriptors.
 *
 * buf/size is the whole module. The first 8 bytes (magic number and
 * version, checked by load()) are skipped. For each section, the type byte
 * and LEB128 body size are read, and a WASMSection is appended to
 * *p_section_list. Its section_body points into buf; the body is not copied.
 *
 * Non-custom sections must appear at most once and in increasing
 * get_section_index() order. Custom (SECTION_TYPE_USER) sections may appear
 * anywhere.
 *
 * Returns false with error_buf set when it finds an unknown section id, a
 * misordered or duplicate section, a section that runs past the end of the
 * buffer, or an allocation failure. Nodes appended before the failure stay
 * linked into *p_section_list, so the caller must still release the list.
 */
static bool
create_sections(const uint8 *buf, uint32 size,
                WASMSection **p_section_list,
                char *error_buf, uint32 error_buf_size)
{
    WASMSection *section_list_end = NULL, *section;
    const uint8 *p = buf, *p_end = buf + size;
    uint8 section_type, section_index, last_section_index = (uint8)-1;
    uint32 section_size;

    bh_assert(!*p_section_list);
    bh_assert(size >= 8);

    /* Skip magic number + version. */
    p += 8;
    while (p < p_end) {
        CHECK_BUF(p, p_end, 1);
        section_type = read_uint8(p);
        section_index = get_section_index(section_type);

        /* Reject malformed input with an error rather than only via
           bh_assert(), which presumably disappears in release builds and
           would let parsing continue on garbage. */
        if (section_index == (uint8)-1) {
            set_error_buf(error_buf, error_buf_size, "invalid section id");
            return false;
        }
        if (section_type != SECTION_TYPE_USER) {
            /* Custom sections may be inserted at any place,
               while other sections must occur at most once
               and in prescribed order. */
            if (last_section_index != (uint8)-1
                && section_index <= last_section_index) {
                set_error_buf(error_buf, error_buf_size,
                              "junk after last section");
                return false;
            }
            last_section_index = section_index;
        }

        CHECK_BUF1(p, p_end, 1);
        read_leb_uint32(p, p_end, section_size);
        CHECK_BUF1(p, p_end, section_size);
        /* Explicit bound so a section body can never extend past the
           module buffer, regardless of how CHECK_BUF1 is configured. */
        if (p > p_end || section_size > (uint32)(p_end - p)) {
            set_error_buf(error_buf, error_buf_size, "unexpected end");
            return false;
        }

        if (!(section = loader_malloc(sizeof(WASMSection),
                                      error_buf, error_buf_size))) {
            return false;
        }

        section->next = NULL;
        section->section_type = section_type;
        section->section_body = (uint8*)p;
        section->section_body_size = section_size;

        if (!*p_section_list)
            *p_section_list = section_list_end = section;
        else {
            section_list_end->next = section;
            section_list_end = section;
        }

        p += section_size;
    }

    return true;
}

load_from_sections

  1. Find the code and function sections, if present
  2. Iterate over each section and extract the related information (TODO: explore the details)
  3. Resolve the auxiliary data/stack/heap info and reset the memory info (checking values exported by the module)
  4. Resolve the malloc/free functions exported by the wasm module (open questions: are custom malloc/free supported? What is the retain function?)
  5. Iterate through all functions and run wasm_loader_prepare_bytecode on each
  6. Perform additional memory initialization if the memory cannot grow