Git Notes

October 12, 2021 at 8:42 pm
Artifact
# Rebase

## Change last commit

git commit --amend --date "$(date)"

# MISC

## Add tag

git tag -a <tag> <commit>
git push --tags

## Remove file from the index (but keep it in the FS)

git rm --cached <file_name>

WAMR+Teaclave

July 15, 2021 at 10:03 am
Analysis Artifact SGX System WASM

Questions

Mesapy Executor

The executor is compiled into a library and is later linked with FFI functions from another executor (mesapy).

How does T know where to link extern functions in Rust?

How is the linked mesapy library generated?

Check the makefile of mesapy

Notes

How to compile the WebAssembly

/opt/wasi-sdk/bin/clang -o simple_add.wasm simple_add.c -Wl,--export-all -Wl,--no-entry -nostdlib -Wl,--allow-undefined

The compiled wasm currently must not be linked against the standard library (hence -nostdlib). Adding -Wl,--export-all makes the linker export the functions; otherwise the function would never be exported.
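To check whether a function really ended up in the export list, one option is to read the export section of the compiled binary directly. The following is a minimal sketch in Python, assuming the standard WebAssembly binary layout (an 8-byte header followed by sections, each an id byte plus a LEB128 size). The helper names `read_uleb128` and `list_wasm_exports` are made up for this note and are not part of wasi-sdk or WAMR.

```python
def read_uleb128(data, pos):
    """Decode an unsigned LEB128 integer starting at pos; return (value, new_pos)."""
    result = 0
    shift = 0
    while True:
        byte = data[pos]
        pos += 1
        result |= (byte & 0x7F) << shift
        if byte & 0x80 == 0:
            return result, pos
        shift += 7


def list_wasm_exports(data):
    """Return (name, kind, index) for every entry in the export section (id 7)."""
    if data[:4] != b"\x00asm":
        raise ValueError("not a WebAssembly binary")
    kinds = {0: "func", 1: "table", 2: "memory", 3: "global"}
    exports = []
    pos = 8  # skip the 4-byte magic and the 4-byte version
    while pos < len(data):
        section_id = data[pos]
        pos += 1
        size, pos = read_uleb128(data, pos)
        end = pos + size
        if section_id == 7:
            count, p = read_uleb128(data, pos)
            for _ in range(count):
                name_len, p = read_uleb128(data, p)
                name = data[p:p + name_len].decode("utf-8")
                p += name_len
                kind = data[p]
                p += 1
                index, p = read_uleb128(data, p)
                exports.append((name, kinds.get(kind, "unknown"), index))
        pos = end  # jump to the next section regardless of its id
    return exports
```

Calling `list_wasm_exports(open("simple_add.wasm", "rb").read())` on the module built above is a quick way to compare the export list with and without `--export-all`.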

Modification on source code

Please refer to the Developer Guide and related documents.

WASM Executor

S-FaaS: Trustworthy and Accountable Function-as-a-Service using Intel SGX

September 12, 2021 at 10:13 pm
Application Paper Security SGX System

PDF

  • Resource accounting on SGX “enclaved” FaaS.
  • Trusted timer: built using TSX + additional timer thread
  • Model: function trusted by user, but not service provider(platform) => sandbox
  • KMS, transitive attestation, encryption
  • Implementation on Apache OpenWhisk

FaaS Papers, mainly in TEE

September 14, 2021 at 10:54 pm
Application Research Security SGX System

S-FaaS: Trustworthy and Accountable Function-as-a-Service using Intel SGX

PDF

  • Resource accounting on SGX “enclaved” FaaS.
  • Trusted timer: built using TSX + additional timer thread
  • Model: function trusted by user, but not service provider(platform) => sandbox
  • KMS, transitive attestation, encryption
  • Implementation on Apache OpenWhisk

Towards Demystifying Serverless Machine Learning Training

PDF

  • Implement a serverless distributed ML framework, LambdaML, including distributed optimization, communication (with a storage server) and synchronization
  • Compare the FaaS and IaaS solution for distributed ML.

Trust more, serverless

Link

  • JS FaaS in SGX enclaves
  • Google V8 engine/Duktape + SGX-LKL + Apache OpenWhisk
  • Key management
  • Parallel, warm start, adjust to load

Clemmys: Towards Secure Remote Execution in FaaS

PDF

  • SGX2: dynamic memory management (EDMM) to speed up enclave init
  • OpenWhisk + Scone + Palaemon(KMS)
  • Gateway(T) + Controller(U) + Worker(T): not all in the enclave
  • Features: function chaining & verification
  • Functions should be manually inspected

Other related papers

Confidential Serverless Made Efficient with Plug-In Enclaves

Secure Computing

August 27, 2021 at 5:06 pm
Research Security

Secure computation is not what I thought it was.

Generally speaking, secure computation lets two (or more) mutually distrusting parties compute a function over their combined data without revealing any party's plaintext input to the others.

This problem originates from the paper Protocols for Secure Computations (PDF). The famous millionaires' problem comes from this paper, and Garbled Circuit (Wiki) later became popular in this field. Dr. Qizhi Yao (Andrew Chi-Chih Yao) initiated the field, and he also founded several other fields in cryptography and theoretical computer science.
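To make "compute on joint data without exposing the plaintext" concrete, here is a toy sketch. Note that it is not Yao's protocol and not a garbled circuit: it uses additive secret sharing, a simpler technique that only handles sums and assumes every party follows the protocol honestly. The function names and the modulus are my own choices for illustration.

```python
import secrets

MODULUS = 2**61 - 1  # public modulus; every input must be smaller than this


def share(value, n_parties):
    """Split value into n_parties random shares that add up to value mod MODULUS."""
    shares = [secrets.randbelow(MODULUS) for _ in range(n_parties - 1)]
    last = (value - sum(shares)) % MODULUS
    return shares + [last]


def secure_sum(inputs):
    """Simulate all parties locally and return the sum of their private inputs."""
    n = len(inputs)
    # dealt[i][j] is the share that party i sends to party j.
    dealt = [share(v, n) for v in inputs]
    # Party j only ever sees column j, which looks uniformly random on its own.
    partial = [sum(dealt[i][j] for i in range(n)) % MODULUS for j in range(n)]
    # Publishing the partial sums reveals the total but not the individual inputs.
    return sum(partial) % MODULUS
```

Comparisons such as the millionaires' problem need more machinery than this (for example garbled circuits), which is exactly why the papers below exist.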

Related Projects

Related Papers

  • ObliVM: A Programming Framework for Secure Computation (PDF)
  • MAGE: Nearly Zero-Cost Virtual Memory for Secure Computation (PDF)

Teaclave

August 20, 2021 at 5:45 pm
Analysis Artifact SGX System

Syscall Fuzzers

August 14, 2021 at 3:10 pm
Fuzz Security System

SGX Middleware (Papers)

May 6, 2021 at 11:46 am
Research Security SGX

Middlewares

Container-like

Language-specific

SDK/API/Runtimes

Parallel

Attacks

SecurityPapersByTopic

April 5, 2021 at 10:06 am
Research Security

SGX Papers

August 13, 2021 at 7:54 pm
Iago Research Security SGX

Meta

Official Documents

Reviews

Attacks:

Defenses:

Applications:

See also:

Related materials other than paper

WAMR Analysis

August 13, 2021 at 7:55 pm
Analysis Artifact Security System WASM

SGX Environment App Code

Big Picture

  • An ecall ecall_handle_command to serve incoming commands from the app part
  • Memory management? runtime_malloc
void
ecall_handle_command(unsigned cmd,
                     unsigned char *cmd_buf,
                     unsigned cmd_buf_size)
{
    uint64 *args = (uint64 *)cmd_buf;
    uint32 argc = cmd_buf_size / sizeof(uint64);

    switch (cmd) {
        case CMD_INIT_RUNTIME:
            handle_cmd_init_runtime(args, argc);
            break;
        case CMD_LOAD_MODULE:
            handle_cmd_load_module(args, argc);
            break;
        case CMD_SET_WASI_ARGS:
            handle_cmd_set_wasi_args(args, argc);
            break;
        case CMD_INSTANTIATE_MODULE:
            handle_cmd_instantiate_module(args, argc);
            break;
        case CMD_LOOKUP_FUNCTION:
            break;
        case CMD_CREATE_EXEC_ENV:
            break;
        case CMD_CALL_WASM:
            break;
        case CMD_EXEC_APP_FUNC:
            handle_cmd_exec_app_func(args, argc);
            break;
        case CMD_EXEC_APP_MAIN:
            handle_cmd_exec_app_main(args, argc);
            break;
        case CMD_GET_EXCEPTION:
            handle_cmd_get_exception(args, argc);
            break;
        case CMD_DEINSTANTIATE_MODULE:
            handle_cmd_deinstantiate_module(args, argc);
            break;
        case CMD_UNLOAD_MODULE:
            handle_cmd_unload_module(args, argc);
            break;
        case CMD_DESTROY_RUNTIME:
            handle_cmd_destroy_runtime();
            break;
        case CMD_SET_LOG_LEVEL:
            handle_cmd_set_log_level(args, argc);
            break;
        default:
            LOG_ERROR("Unknown command %d\n", cmd);
            break;
    }
}
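The argument convention used by the dispatcher (a flat buffer reinterpreted as an array of uint64, with argc derived from the buffer size) can be sketched from the caller's point of view as follows. This is a Python model for illustration only; the function names and the command numbers used with it are hypothetical and do not come from WAMR.

```python
import struct


def pack_command_args(*values):
    """Pack integers as 8-byte unsigned values, like the uint64 array the ecall reads."""
    return struct.pack(f"={len(values)}Q", *values)


def unpack_command_args(cmd_buf):
    """Model of `args = (uint64 *)cmd_buf; argc = cmd_buf_size / sizeof(uint64)`."""
    argc = len(cmd_buf) // 8
    args = list(struct.unpack(f"={argc}Q", cmd_buf[:argc * 8]))
    return args, argc


def handle_command(cmd, cmd_buf, handlers):
    """Dispatch on cmd like the switch statement; handlers maps cmd -> callable."""
    args, argc = unpack_command_args(cmd_buf)
    handler = handlers.get(cmd)
    if handler is None:
        return f"Unknown command {cmd}"
    return handler(args, argc)
```

Because every argument travels as a uint64 slot, pointers and sizes share the same buffer, and each handler has to know how many slots it expects.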

Initialization Steps

  1. enclave_init
  2. Argument parsing
  3. WAMR `init_runtime` (in enclave)
  4. Set log level
  5. Read in WASM file to a buffer and load it as a module into the enclave
  6. Instantiate the module
  7. Execute the main function
  8. Clean up: wasm_runtime_deinstantiate, wasm_runtime_unload and wasm_runtime_destroy
int
main(int argc, char *argv[])
{
    char *wasm_file = NULL;
    const char *func_name = NULL;
    uint8_t *wasm_file_buf = NULL;
    uint32_t wasm_file_size;
    uint32_t stack_size = 16 * 1024, heap_size = 16 * 1024;
    void *wasm_module = NULL;
    void *wasm_module_inst = NULL;
    char error_buf[128] = { 0 };
    int log_verbose_level = 2;
    bool is_repl_mode = false, alloc_with_pool = false;
    const char *dir_list[8] = { NULL };
    uint32_t dir_list_size = 0;
    const char *env_list[8] = { NULL };
    uint32_t env_list_size = 0;
    uint32_t max_thread_num = 4;

    if (enclave_init(&g_eid) < 0) {
        std::cout << "Fail to initialize enclave." << std::endl;
        return 1;
    }

#if TEST_OCALL_API != 0
    {
        if (!init_runtime(alloc_with_pool, max_thread_num)) {
            return -1;
        }
        ecall_iwasm_test(g_eid);
        destroy_runtime();
        return 0;
    }
#endif

    /* Process options. */
    for (argc--, argv++; argc > 0 && argv[0][0] == '-'; argc--, argv++) {
        if (!strcmp(argv[0], "-f") || !strcmp(argv[0], "--function")) {
            argc--, argv++;
            if (argc < 2) {
                print_help();
                return 0;
            }
            func_name = argv[0];
        }
        else if (!strncmp(argv[0], "-v=", 3)) {
            log_verbose_level = atoi(argv[0] + 3);
            if (log_verbose_level < 0 || log_verbose_level > 5)
                return print_help();
        }
        else if (!strcmp(argv[0], "--repl")) {
            is_repl_mode = true;
        }
        else if (!strncmp(argv[0], "--stack-size=", 13)) {
            if (argv[0][13] == '\0')
                return print_help();
            stack_size = atoi(argv[0] + 13);
        }
        else if (!strncmp(argv[0], "--heap-size=", 12)) {
            if (argv[0][12] == '\0')
                return print_help();
            heap_size = atoi(argv[0] + 12);
        }
        else if (!strncmp(argv[0], "--dir=", 6)) {
            if (argv[0][6] == '\0')
                return print_help();
            if (dir_list_size >= sizeof(dir_list) / sizeof(char *)) {
                printf("Only allow max dir number %d\n",
                       (int)(sizeof(dir_list) / sizeof(char *)));
                return -1;
            }
            dir_list[dir_list_size++] = argv[0] + 6;
        }
        else if (!strncmp(argv[0], "--env=", 6)) {
            char *tmp_env;

            if (argv[0][6] == '\0')
                return print_help();
            if (env_list_size >= sizeof(env_list) / sizeof(char *)) {
                printf("Only allow max env number %d\n",
                       (int)(sizeof(env_list) / sizeof(char *)));
                return -1;
            }
            tmp_env = argv[0] + 6;
            if (validate_env_str(tmp_env))
                env_list[env_list_size++] = tmp_env;
            else {
                printf("Wasm parse env string failed: expect \"key=value\", "
                       "got \"%s\"\n",
                       tmp_env);
                return print_help();
            }
        }
        else if (!strncmp(argv[0], "--max-threads=", 14)) {
            if (argv[0][14] == '\0')
                return print_help();
            max_thread_num = atoi(argv[0] + 14);
        }
        else
            return print_help();
    }

    if (argc == 0)
        return print_help();

    wasm_file = argv[0];

    /* Init runtime */
    if (!init_runtime(alloc_with_pool, max_thread_num)) {
        return -1;
    }

    /* Set log verbose level */
    if (!set_log_verbose_level(log_verbose_level)) {
        goto fail1;
    }

    /* Load WASM byte buffer from WASM bin file */
    if (!(wasm_file_buf =
            (uint8_t *)read_file_to_buffer(wasm_file, &wasm_file_size))) {
        goto fail1;
    }

    /* Load module */
    if (!(wasm_module = load_module(wasm_file_buf, wasm_file_size,
                                    error_buf, sizeof(error_buf)))) {
        printf("%s\n", error_buf);
        goto fail2;
    }

    /* Set wasi arguments */
    if (!set_wasi_args(wasm_module, dir_list, dir_list_size,
                       env_list, env_list_size, argv, argc)) {
        printf("%s\n", "set wasi arguments failed.\n");
        goto fail3;
    }

    /* Instantiate module */
    if (!(wasm_module_inst = instantiate_module(wasm_module,
                                                stack_size, heap_size,
                                                error_buf,
                                                sizeof(error_buf)))) {
        printf("%s\n", error_buf);
        goto fail3;
    }

    if (is_repl_mode)
        app_instance_repl(wasm_module_inst, argc, argv);
    else if (func_name)
        app_instance_func(wasm_module_inst, func_name,
                          argc - 1, argv + 1);
    else
        app_instance_main(wasm_module_inst, argc, argv);

    /* Deinstantiate module */
    deinstantiate_module(wasm_module_inst);

fail3:
    /* Unload module */
    unload_module(wasm_module);

fail2:
    /* Free the file buffer */
    free(wasm_file_buf);

fail1:
    /* Destroy runtime environment */
    destroy_runtime();

    return 0;
}
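The goto-based cleanup in main (fail3, fail2, fail1) undoes only the steps that already succeeded, in reverse order. A small Python model of that pattern, using contextlib.ExitStack, is sketched below. The function `run_steps` and the tuple format are assumptions made for this note and are unrelated to the WAMR API.

```python
from contextlib import ExitStack


def run_steps(steps, log):
    """Run steps in order; on failure or completion, undo finished steps in reverse.

    steps is a list of (name, action, cleanup) tuples. action() returns False
    on failure; cleanup is a callable or None when there is nothing to undo.
    """
    with ExitStack() as stack:
        for name, action, cleanup in steps:
            if not action():
                log.append(f"{name} failed")
                return False  # leaving the with-block runs the registered cleanups
            log.append(name)
            if cleanup is not None:
                stack.callback(cleanup)  # callbacks run last-in, first-out
        return True
```

A failing load_module step, for instance, triggers only the buffer and runtime cleanups, which corresponds to jumping to fail2 in the C code.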

wasmtime

July 22, 2021 at 12:19 pm
Analysis Artifact System WASM

TVM Code Generation

July 23, 2021 at 12:22 pm
Analysis Artifact WASM

WASM Code Generation Debug Note

Analysis

TVM Build

This Python interface checks the validity of the input and invokes codegen.build_module

codegen.build_module

def build_module(mod, target):
    """Build IRModule into Module.

    Parameters
    ----------
    mod : tvm.IRModule
        The ir module.

    target : str
        The target module type.

    Returns
    -------
    module : runtime.Module
        The corresponding module.
    """
    target = Target(target) if isinstance(target, str) else target
    return _ffi_api.Build(mod, target)

FFI

The FFI is one of the more opaque parts of TVM. The _ffi_api module is:

"""FFI APIs for tvm.target"""
import tvm._ffi


tvm._ffi._init_api("target", __name__)

Then:

def _init_api(namespace, target_module_name=None):
    """Initialize api for a given module name

    namespace : str
       The namespace of the source registry

    target_module_name : str
       The target module name if different from namespace
    """
    target_module_name = target_module_name if target_module_name else namespace
    if namespace.startswith("tvm."):
        _init_api_prefix(target_module_name, namespace[4:])
    else:
        _init_api_prefix(target_module_name, namespace)


def _init_api_prefix(module_name, prefix):
    module = sys.modules[module_name]

    for name in list_global_func_names():
        if not name.startswith(prefix):
            continue

        fname = name[len(prefix) + 1 :]
        target_module = module

        if fname.find(".") != -1:
            continue
        f = get_global_func(name)
        ff = _get_api(f)
        ff.__name__ = fname
        ff.__doc__ = "TVM PackedFunc %s. " % fname
        setattr(target_module, ff.__name__, ff)

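It seems the FFI layer in TVM automatically collects every registered global function whose name starts with the given prefix and attaches it to the Python module as an attribute. Names that still contain a dot after the prefix is stripped are skipped.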
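To convince myself of how the prefix logic behaves, here is a toy re-creation in plain Python. The dictionary stands in for TVM's global function registry. The entries "target.Build" and "target.build.llvm" are names taken from these notes, while "other.helper" and the lambdas are placeholders I invented.

```python
import types

# Toy stand-in for the global function registry (name -> callable).
GLOBAL_FUNCS = {
    "target.Build": lambda mod, target: f"built {mod} for {target}",
    "target.build.llvm": lambda mod, target: f"llvm module for {mod}",
    "other.helper": lambda: "unrelated",
}


def init_api_prefix(module, prefix):
    """Attach every function registered directly under prefix to module."""
    for name, func in GLOBAL_FUNCS.items():
        if not name.startswith(prefix):
            continue
        fname = name[len(prefix) + 1:]  # drop "<prefix>."
        if "." in fname:
            continue  # nested names such as "build.llvm" are not attached
        setattr(module, fname, func)


# Plays the role of the _ffi_api module for the "target" namespace.
ffi_api = types.ModuleType("toy_target_ffi_api")
init_api_prefix(ffi_api, "target")
```

After this runs, `ffi_api.Build` exists, which matches how `_ffi_api.Build(mod, target)` becomes callable from Python without being defined in any Python file.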

So I then looked at the related module.

src/target/codegen.cc

There is a Build method which is also registered:

runtime::Module Build(IRModule mod, Target target) {
  if (transform::PassContext::Current()
          ->GetConfig<Bool>("tir.disable_assert", Bool(false))
          .value()) {
    mod = tir::transform::SkipAssert()(mod);
  }

  // the build function.
  std::string build_f_name = "target.build." + target->kind->name;
  const PackedFunc* bf = runtime::Registry::Get(build_f_name);
  ICHECK(bf != nullptr) << build_f_name << " is not enabled";
  return (*bf)(mod, target);
}

// Some code here

TVM_REGISTER_GLOBAL("target.Build").set_body_typed(Build);

In my case, this function indirectly calls target.build.llvm, because the registry key is built from the name of the target kind.
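The string-based dispatch in Build can be modelled in a few lines of Python. This is only a sketch of the lookup pattern: `register_global` imitates the role of TVM_REGISTER_GLOBAL, and the return values are dummy tuples rather than real modules.

```python
REGISTRY = {}


def register_global(name):
    """Decorator playing the role of TVM_REGISTER_GLOBAL in this toy model."""
    def decorator(func):
        REGISTRY[name] = func
        return func
    return decorator


@register_global("target.build.llvm")
def build_llvm(mod, target_kind):
    return ("llvm_module", mod)


@register_global("target.Build")
def build(mod, target_kind):
    # The backend is chosen purely by name, so a backend that was never
    # registered (not compiled in) shows up as a failed lookup here.
    build_f_name = "target.build." + target_kind
    bf = REGISTRY.get(build_f_name)
    if bf is None:
        raise RuntimeError(build_f_name + " is not enabled")
    return bf(mod, target_kind)
```

This also explains the "is not enabled" error message: it means no function was registered under the derived name.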

src/target/llvm/llvm_module.cc

Just like the previous one, the function has been registered:

TVM_REGISTER_GLOBAL("target.build.llvm")
    .set_body_typed([](IRModule mod, Target target) -> runtime::Module {
      auto n = make_object<LLVMModuleNode>();
      n->Init(mod, target);
      return runtime::Module(n);
    });

This function invokes LLVMModuleNode::Init.

It begins with some parameter checks.

It then invokes CodeGenLLVM::Init, which initializes the code generator for LLVM.

The core functions for generating functions in LLVM are:


void CodeGenLLVM::AddFunction(const PrimFunc& f) { this->AddFunctionInternal(f, false); }

void CodeGenLLVM::InitFuncState() {
  var_map_.clear();
  alias_var_set_.clear();
  alloc_storage_info_.clear();
  volatile_buf_.clear();
  analyzer_.reset(new arith::Analyzer());
}

void CodeGenLLVM::AddFunctionInternal(const PrimFunc& f, bool ret_void) {
  this->InitFuncState();

  ICHECK_EQ(f->buffer_map.size(), 0U)
      << "Cannot codegen function with buffer_map, please lower them first";

  std::vector<llvm::Type*> param_types;
  is_restricted_ = f->HasNonzeroAttr(tir::attr::kNoAlias);
  for (Var param : f->params) {
    param_types.push_back(GetLLVMType(param));
    if (!is_restricted_ && param.dtype().is_handle()) {
      alias_var_set_.insert(param.get());
    }
  }
  // TODO(tvm-team):
  // Update the function type to respect the ret_type field of f.
  // Once we allow more flexibility in the PrimFunc.
  llvm::FunctionType* ftype =
      llvm::FunctionType::get(ret_void ? t_void_ : t_int_, param_types, false);

  auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
  ICHECK(global_symbol.defined())
      << "CodeGenLLVM: Expect PrimFunc to have the global_symbol attribute";
  ICHECK(module_->getFunction(static_cast<std::string>(global_symbol.value())) == nullptr)
      << "Function " << global_symbol << " already exist in module";

  function_ = llvm::Function::Create(ftype, llvm::Function::ExternalLinkage,
                                     global_symbol.value().operator std::string(), module_.get());
  function_->setCallingConv(llvm::CallingConv::C);
  function_->setDLLStorageClass(llvm::GlobalValue::DLLStorageClassTypes::DLLExportStorageClass);

  // set var map and align information
  auto arg_it = function_->arg_begin();
  for (size_t i = 0; i < f->params.size(); ++i, ++arg_it) {
    llvm::Argument* v = &(*arg_it);
    const Var& var = f->params[i];
    var_map_[var.get()] = v;
    if (is_restricted_) {
      if (var.dtype().is_handle() && !alias_var_set_.count(var.get())) {
        // set non alias.
#if TVM_LLVM_VERSION >= 50
        function_->addParamAttr(i, llvm::Attribute::NoAlias);
#else
        function_->setDoesNotAlias(i + 1);
#endif
      }
    }
  }
  llvm::BasicBlock* entry = llvm::BasicBlock::Create(*ctx_, "entry", function_);
  builder_->SetInsertPoint(entry);
  this->VisitStmt(f->body);

  // Add alignment attribute if needed.
#if TVM_LLVM_VERSION >= 50
  for (size_t i = 0; i < f->params.size(); ++i) {
    const Var& var = f->params[i];
    auto f = alloc_storage_info_.find(var.get());
    if (f != alloc_storage_info_.end()) {
      unsigned align = f->second.alignment;
      if (align > 1) {
        auto attr = llvm::Attribute::get(*ctx_, llvm::Attribute::Alignment, align);
        function_->addParamAttr(i, attr);
      }
    }
  }
#endif

  llvm::StringRef fs = target_machine_->getTargetFeatureString();
  if (!fs.empty()) {
    function_->addFnAttr("target-features", fs);
  }

  if (ret_void) {
    builder_->CreateRetVoid();
  } else {
    builder_->CreateRet(ConstInt32(0));
  }
}

However, there is no function that has the same signature as BackendPackedCFunc.

Tir Level Transformation

After searching for a while, I finally found where the BackendPackedCFunc signature is generated. It is invoked from the Python script via mod_host, mdev = _build_for_device(input_mod, tar, target_host):

def _build_for_device(input_mod, target, target_host):
    """Build the lowered functions for a device with the given compilation
    target.

    Parameters
    ----------
    input_mod : IRModule
        The schedule to be built.

    target : str or :any:`tvm.target.Target`
        The target and option of the compilation.

    target_host : str or :any:`tvm.target.Target`
        The host compilation target.

    Returns
    -------
    fhost : IRModule
        The host IRModule.

    mdev : tvm.module
        A module that contains device code.
    """
    target, target_host = Target.check_and_update_host_consist(target, target_host)
    device_type = ndarray.device(target.kind.name, 0).device_type

    mod_mixed = input_mod
    mod_mixed = tvm.tir.transform.Apply(lambda f: f.with_attr("target", target))(mod_mixed)

    opt_mixed = [tvm.tir.transform.VerifyMemory()]
    if len(mod_mixed.functions) == 1:
        opt_mixed += [tvm.tir.transform.Apply(lambda f: f.with_attr("tir.is_entry_func", True))]

    if PassContext.current().config.get("tir.detect_global_barrier", False):
        opt_mixed += [tvm.tir.transform.ThreadSync("global")]
    opt_mixed += [
        tvm.tir.transform.ThreadSync("shared"),
        tvm.tir.transform.ThreadSync("warp"),
        tvm.tir.transform.InferFragment(),
        tvm.tir.transform.LowerThreadAllreduce(),
        tvm.tir.transform.MakePackedAPI(),
        tvm.tir.transform.SplitHostDevice(),
    ]
    mod_mixed = tvm.transform.Sequential(opt_mixed)(mod_mixed)

    # device optimizations
    opt_device = tvm.transform.Sequential(
        [
            tvm.tir.transform.Filter(
                lambda f: "calling_conv" in f.attrs
                and f.attrs["calling_conv"].value == CallingConv.DEVICE_KERNEL_LAUNCH
            ),
            tvm.tir.transform.LowerWarpMemory(),
            tvm.tir.transform.Simplify(),
            tvm.tir.transform.LowerDeviceStorageAccessInfo(),
            tvm.tir.transform.LowerCustomDatatypes(),
            tvm.tir.transform.LowerIntrin(),
        ]
    )
    mod_dev = opt_device(mod_mixed)

    # host optimizations
    opt_host = tvm.transform.Sequential(
        [
            tvm.tir.transform.Filter(
                lambda f: "calling_conv" not in f.attrs
                or f.attrs["calling_conv"].value != CallingConv.DEVICE_KERNEL_LAUNCH
            ),
            tvm.tir.transform.Apply(lambda f: f.with_attr("target", target_host)),
            tvm.tir.transform.LowerTVMBuiltin(),
            tvm.tir.transform.LowerDeviceStorageAccessInfo(),
            tvm.tir.transform.LowerCustomDatatypes(),
            tvm.tir.transform.LowerIntrin(),
            tvm.tir.transform.CombineContextCall(),
        ]
    )
    mod_host = opt_host(mod_mixed)

    if device_type == ndarray.cpu(0).device_type and target_host == target:
        assert len(mod_dev.functions) == 0
    if "gpu" in target.keys and len(mod_dev.functions) == 0:
        warnings.warn(
            "Specified target %s, but cannot find device code, did you do " "bind?" % target
        )

    rt_mod_dev = codegen.build_module(mod_dev, target) if len(mod_dev.functions) != 0 else None
    return mod_host, rt_mod_dev

Here, the MakePackedAPI pass generates the function with the BackendPackedCFunc signature.

PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) {
  auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
  ICHECK(global_symbol) << "MakePackedAPI: Expect PrimFunc to have the global_symbol attribute";

  auto target = func->GetAttr<Target>(tvm::attr::kTarget);
  ICHECK(target.defined()) << "MakePackedAPI: Require the target attribute";
  int target_device_type = target.value()->kind->device_type;

  std::string name_hint = global_symbol.value();

  auto* func_ptr = func.CopyOnWrite();
  const Stmt nop = Evaluate(0);
  int num_args = static_cast<int>(func_ptr->params.size());
  ICHECK_LE(num_unpacked_args, num_args);
  bool pack_args = (num_unpacked_args == -1) || (num_args > num_unpacked_args);
  if (num_unpacked_args == -1) {
    // reset to zero
    num_unpacked_args = 0;
  }
  ICHECK_GE(num_unpacked_args, 0);
  int num_packed_args = num_args - num_unpacked_args;
  // Data field definitions
  // The packed fields
  Var v_packed_args("args", DataType::Handle());
  Var v_packed_arg_type_ids("arg_type_ids", DataType::Handle());
  Var v_num_packed_args("num_args", DataType::Int(32));
  Var v_out_ret_value("out_ret_value", DataType::Handle());
  Var v_out_ret_tcode("out_ret_tcode", DataType::Handle());
  Var v_resource_handle("resource_handle", DataType::Handle());
  // The arguments of the function.
  Array<Var> args;
  // The device context
  Var device_id("dev_id");
  Integer device_type(target_device_type);
  // seq_init gives sequence of initialization
  // seq_check gives sequence of later checks after init
  std::vector<Stmt> seq_init, seq_check;
  std::unordered_map<const VarNode*, PrimExpr> vmap;
  ArgBinder binder(&vmap);
  // ---------------------------
  // local function definitions
  // load i-th argument as type t
  auto f_arg_value = [&](DataType t, int i) {
    Array<PrimExpr> call_args{v_packed_args, IntImm(DataType::Int(32), i),
                              IntImm(DataType::Int(32), builtin::kTVMValueContent)};
    // load 64 bit version
    DataType api_type = APIType(t);
    PrimExpr res = Call(api_type, builtin::tvm_struct_get(), call_args);
    // cast to the target version.
    if (api_type != t) {
      res = Cast(t, res);
    }
    return res;
  };
  // ---------------------------
  // start of logics
  // add signature for packed arguments.
  if (pack_args) {
    args.push_back(v_packed_args);
    args.push_back(v_packed_arg_type_ids);
    args.push_back(v_num_packed_args);
    std::ostringstream os;

    os << name_hint << ": num_args should be " << num_packed_args;
    seq_init.emplace_back(MakeAssertEQ(v_num_packed_args, num_packed_args, os.str()));
  }

  // Need to re-declare vars, in case some arguments also appears in the buffer.
  std::vector<std::pair<Var, Var> > var_def;
  std::vector<std::pair<Var, Buffer> > buffer_def;

  for (int i = 0; i < static_cast<int>(func_ptr->params.size()); ++i) {
    Var param = func_ptr->params[i];
    Var v_arg = Var("arg" + std::to_string(i), param->dtype);

    auto it = func_ptr->buffer_map.find(param);
    if (it != func_ptr->buffer_map.end()) {
      buffer_def.emplace_back(v_arg, (*it).second);
    } else {
      var_def.emplace_back(v_arg, param);
    }
    if (i < num_packed_args) {
      // Value loads
      seq_init.emplace_back(LetStmt(v_arg, f_arg_value(v_arg.dtype(), i), nop));
      // type code checks
      Var tcode(v_arg->name_hint + ".code", DataType::Int(32));
      seq_init.emplace_back(LetStmt(tcode,
                                    Load(DataType::Int(32), v_packed_arg_type_ids,
                                         IntImm(DataType::Int(32), i), const_true(1)),
                                    nop));
      DataType t = v_arg.dtype();
      if (t.is_handle()) {
        std::ostringstream msg;
        msg << name_hint << ": Expect arg[" << i << "] to be pointer";
        seq_check.emplace_back(AssertStmt(tcode == kTVMOpaqueHandle || tcode == kTVMNDArrayHandle ||
                                              tcode == kTVMDLTensorHandle || tcode == kTVMNullptr,
                                          tvm::tir::StringImm(msg.str()), nop));
      } else if (t.is_int() || t.is_uint()) {
        std::ostringstream msg;
        msg << name_hint << ": Expect arg[" << i << "] to be int";
        seq_check.emplace_back(AssertStmt(tcode == kDLInt, tvm::tir::StringImm(msg.str()), nop));
      } else {
        ICHECK(t.is_float());
        std::ostringstream msg;
        msg << name_hint << ": Expect arg[" << i << "] to be float";
        seq_check.emplace_back(AssertStmt(tcode == kDLFloat, tvm::tir::StringImm(msg.str()), nop));
      }
    } else {
      args.push_back(v_arg);
    }
  }

  // allow return value if the function is packed.
  if (pack_args) {
    args.push_back(v_out_ret_value);
    args.push_back(v_out_ret_tcode);
    args.push_back(v_resource_handle);
  }

  size_t expected_nargs = num_unpacked_args + (pack_args ? 6 : 0);
  ICHECK_EQ(args.size(), expected_nargs);

  // Arg definitions are defined before buffer binding to avoid the use before
  // def errors.
  //
  // For example, for auto broadcasting, checks are required to guarantee that
  // either 0 or the original stride will be correctly used. Checks here have
  // to use the args that may have no let binding yet. Therefore, hoisting let
  // binding for args before buffer declaration is needed.
  for (const auto& kv : var_def) {
    binder.Bind(kv.second, kv.first, kv.first->name_hint, true);
  }

  for (const auto& kv : buffer_def) {
    binder.BindDLTensor(kv.second, device_type, device_id, kv.first, kv.first->name_hint);
  }

  if (num_unpacked_args == 0) {
    func = WithAttr(std::move(func), tvm::attr::kCallingConv, Integer(CallingConv::kCPackedFunc));
  }

  Stmt body = RewriteReturn(func_ptr->body, v_out_ret_value, v_out_ret_tcode);
  body = AttrStmt(make_zero(DataType::Int(32)), attr::compute_scope,
                  StringImm(name_hint + "_compute_"), body);
  // Set device context
  if (vmap.count(device_id.get())) {
    PrimExpr node = StringImm("default");
    seq_check.push_back(AttrStmt(node, attr::device_id, device_id, nop));
    seq_check.push_back(AttrStmt(node, attr::device_type, device_type, nop));

    if (runtime::DeviceAPI::NeedSetDevice(target_device_type)) {
      Stmt set_device =
          Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(),
                        {StringImm(runtime::symbol::tvm_set_device), device_type, device_id}));
      body = SeqStmt({set_device, body});
    }
  }
  func_ptr->body = MergeNest({seq_init, binder.init_nest(), seq_check, binder.asserts()}, body);
  func_ptr->params = args;

  Array<Var> undefined = UndefinedVars(func_ptr->body, func_ptr->params);
  if (undefined.size() != 0) {
    std::ostringstream os;
    for (Var v : undefined) {
      os << " \'" << v->name_hint << "\' ";
    }
    os << " is not bound to any variables";
    LOG(FATAL) << "Not all Vars are passed in api_args: " << os.str();
  }

  func_ptr->buffer_map = Map<Var, Buffer>();
  func_ptr->checked_type_ = func_ptr->func_type_annotation();
  func_ptr->ret_type = PrimType(DataType::Int(32));

  // return the function.
  return std::move(func);
}
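What the pass does to a function can be summarized with a small Python model: the original typed parameters are replaced by the six packed parameters (args, arg_type_ids, num_args, out_ret_value, out_ret_tcode, resource_handle), the wrapper checks the argument count and type codes, and the result is written through the out parameters while the function itself returns 0. This is a sketch under assumptions: the type code constants are illustrative stand-ins, one-element lists play the role of output pointers, and `make_packed_api` is my own helper, not a TVM API.

```python
K_INT, K_FLOAT = 0, 2  # illustrative type codes, not read from TVM headers


def make_packed_api(func, arg_codes, name_hint="func"):
    """Wrap func(*typed_args) into a function with the packed six-parameter shape."""
    num_packed_args = len(arg_codes)

    def packed(args, arg_type_ids, num_args, out_ret_value, out_ret_tcode,
               resource_handle):
        # Corresponds to the num_args assertion emitted into seq_init.
        if num_args != num_packed_args:
            raise ValueError(f"{name_hint}: num_args should be {num_packed_args}")
        unpacked = []
        for i, expected in enumerate(arg_codes):
            # Corresponds to the per-argument type code checks in seq_check.
            if arg_type_ids[i] != expected:
                kind = "int" if expected == K_INT else "float"
                raise TypeError(f"{name_hint}: Expect arg[{i}] to be {kind}")
            unpacked.append(args[i])
        result = func(*unpacked)
        # Corresponds to RewriteReturn: the value leaves through out parameters.
        out_ret_value[0] = result
        out_ret_tcode[0] = K_FLOAT if isinstance(result, float) else K_INT
        return 0  # the packed function itself returns an int status

    return packed
```

For example, wrapping `lambda a, b: a + b` with `[K_INT, K_INT]` gives a callable that takes the argument list, the type code list and the count, and leaves the sum in `out_ret_value[0]`.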