ya0guang's notebook

Personality Backup

SecurityPapersByTopic

April 5, 2021 at 10:06 am
Research Security

wasmtime

July 22, 2021 at 12:19 pm
Analysis Artifact System WASM

TVM Code Generation

July 23, 2021 at 12:22 pm
Analysis Artifact WASM

WASM Code Generation Debug Note

Analysis

TVM Build

This Python interface checks the validity of the input and invokes codegen.build_module

codegen.build_module

def build_module(mod, target):
    """Build an IRModule into a runtime Module for the given target.

    Parameters
    ----------
    mod : tvm.IRModule
        The IR module to build.

    target : str
        The target module type. A string is wrapped in ``Target`` first;
        any other value is forwarded unchanged.

    Returns
    -------
    module : runtime.Module
        The corresponding module.
    """
    if isinstance(target, str):
        target = Target(target)
    return _ffi_api.Build(mod, target)

FFI

The FFI is one of the more mysterious parts of TVM. The _ffi_api module for tvm.target is:

"""FFI APIs for tvm.target"""
import tvm._ffi


# Attach the global functions registered under the "target" prefix to this
# module as attributes (see tvm._ffi._init_api).
tvm._ffi._init_api("target", __name__)

Then:

def _init_api(namespace, target_module_name=None):
    """Initialize the API of a module from the functions of a namespace.

    namespace : str
       The namespace of the source registry. A leading "tvm." is dropped
       before the namespace is used as the lookup prefix.

    target_module_name : str
       The target module name if different from namespace
    """
    if not target_module_name:
        target_module_name = namespace
    prefix = namespace[4:] if namespace.startswith("tvm.") else namespace
    _init_api_prefix(target_module_name, prefix)


def _init_api_prefix(module_name, prefix):
    """Attach the global functions whose names start with ``prefix`` to a module.

    module_name : str
       Key into ``sys.modules`` of the module that receives the attributes.

    prefix : str
       Name prefix to match. The prefix plus one following character is
       stripped from the matched name to form the attribute name.
    """
    module = sys.modules[module_name]
    strip_len = len(prefix) + 1

    for full_name in list_global_func_names():
        if not full_name.startswith(prefix):
            continue

        short_name = full_name[strip_len:]
        # Anything that still contains a "." after stripping is skipped.
        if "." in short_name:
            continue

        wrapped = _get_api(get_global_func(full_name))
        wrapped.__name__ = short_name
        wrapped.__doc__ = "TVM PackedFunc %s. " % short_name
        setattr(module, wrapped.__name__, wrapped)

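It seems that TVM's FFI initialisation automatically collects every registered global function whose name starts with the given prefix and attaches it to the corresponding Python module as an attribute.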

So I then looked at the related module.

src/target/codegen.cc

There is a Build method which is also registered:

// Build `mod` for `target` by dispatching to the build function registered
// under the name "target.build.<target kind name>".
runtime::Module Build(IRModule mod, Target target) {
  // Apply the SkipAssert pass first when "tir.disable_assert" is set in the
  // current pass context (the option defaults to false).
  auto pass_ctx = transform::PassContext::Current();
  if (pass_ctx->GetConfig<Bool>("tir.disable_assert", Bool(false)).value()) {
    mod = tir::transform::SkipAssert()(mod);
  }

  // Look up the target-specific build function in the global registry.
  const std::string builder_name = "target.build." + target->kind->name;
  const PackedFunc* builder = runtime::Registry::Get(builder_name);
  ICHECK(builder != nullptr) << builder_name << " is not enabled";
  return (*builder)(mod, target);
}

// Some code here

// Register Build in the global function registry under the name "target.Build".
TVM_REGISTER_GLOBAL("target.Build").set_body_typed(Build);

Here, Build looks up the registered function named target.build.<target kind name> and calls it; in my case that is target.build.llvm.

src/target/llvm/llvm_module.cc

Just like the previous one, the function has been registered:

// Register the LLVM build function under the name "target.build.llvm".
TVM_REGISTER_GLOBAL("target.build.llvm")
    .set_body_typed([](IRModule mod, Target target) -> runtime::Module {
      // Create the module node, initialize it from the IRModule and target,
      // and wrap it in a runtime::Module.
      auto llvm_node = make_object<LLVMModuleNode>();
      llvm_node->Init(mod, target);
      return runtime::Module(llvm_node);
    });

And this function invokes LLVMModuleNode::Init

At the beginning it does some parameter checks.

The function here invokes CodeGenLLVM::Init, which initializes the code generator for LLVM.

The core functions for generating functions in LLVM are:


// Add `f` to the module; forwards to AddFunctionInternal with ret_void = false.
void CodeGenLLVM::AddFunction(const PrimFunc& f) {
  this->AddFunctionInternal(f, /*ret_void=*/false);
}

// Reset the per-function state before generating a new function.
void CodeGenLLVM::InitFuncState() {
  // Start from a fresh arithmetic analyzer.
  analyzer_.reset(new arith::Analyzer());
  // Drop everything recorded for the previous function.
  volatile_buf_.clear();
  alloc_storage_info_.clear();
  alias_var_set_.clear();
  var_map_.clear();
}

void CodeGenLLVM::AddFunctionInternal(const PrimFunc& f, bool ret_void) {
  this->InitFuncState();

  ICHECK_EQ(f->buffer_map.size(), 0U)
      << "Cannot codegen function with buffer_map, please lower them first";

  std::vector<llvm::Type*> param_types;
  is_restricted_ = f->HasNonzeroAttr(tir::attr::kNoAlias);
  for (Var param : f->params) {
    param_types.push_back(GetLLVMType(param));
    if (!is_restricted_ && param.dtype().is_handle()) {
      alias_var_set_.insert(param.get());
    }
  }
  // TODO(tvm-team):
  // Update the function type to respect the ret_type field of f.
  // Once we allow more flexibility in the PrimFunc.
  llvm::FunctionType* ftype =
      llvm::FunctionType::get(ret_void ? t_void_ : t_int_, param_types, false);

  auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
  ICHECK(global_symbol.defined())
      << "CodeGenLLVM: Expect PrimFunc to have the global_symbol attribute";
  ICHECK(module_->getFunction(static_cast<std::string>(global_symbol.value())) == nullptr)
      << "Function " << global_symbol << " already exist in module";

  function_ = llvm::Function::Create(ftype, llvm::Function::ExternalLinkage,
                                     global_symbol.value().operator std::string(), module_.get());
  function_->setCallingConv(llvm::CallingConv::C);
  function_->setDLLStorageClass(llvm::GlobalValue::DLLStorageClassTypes::DLLExportStorageClass);

  // set var map and align information
  auto arg_it = function_->arg_begin();
  for (size_t i = 0; i < f->params.size(); ++i, ++arg_it) {
    llvm::Argument* v = &(*arg_it);
    const Var& var = f->params[i];
    var_map_[var.get()] = v;
    if (is_restricted_) {
      if (var.dtype().is_handle() && !alias_var_set_.count(var.get())) {
        // set non alias.
#if TVM_LLVM_VERSION >= 50
        function_->addParamAttr(i, llvm::Attribute::NoAlias);
#else
        function_->setDoesNotAlias(i + 1);
#endif
      }
    }
  }
  llvm::BasicBlock* entry = llvm::BasicBlock::Create(*ctx_, "entry", function_);
  builder_->SetInsertPoint(entry);
  this->VisitStmt(f->body);

  // Add alignment attribute if needed.
#if TVM_LLVM_VERSION >= 50
  for (size_t i = 0; i < f->params.size(); ++i) {
    const Var& var = f->params[i];
    auto f = alloc_storage_info_.find(var.get());
    if (f != alloc_storage_info_.end()) {
      unsigned align = f->second.alignment;
      if (align > 1) {
        auto attr = llvm::Attribute::get(*ctx_, llvm::Attribute::Alignment, align);
        function_->addParamAttr(i, attr);
      }
    }
  }
#endif

  llvm::StringRef fs = target_machine_->getTargetFeatureString();
  if (!fs.empty()) {
    function_->addFnAttr("target-features", fs);
  }

  if (ret_void) {
    builder_->CreateRetVoid();
  } else {
    builder_->CreateRet(ConstInt32(0));
  }
}

However, none of these functions has the same signature as BackendPackedCFunc.

Tir Level Transformation

After searching for a while, I finally found the place where the BackendPackedCFunc signature is generated. It happens inside _build_for_device, which is invoked in the Python script as mod_host, mdev = _build_for_device(input_mod, tar, target_host):

def _build_for_device(input_mod, target, target_host):
    """Lower a module for one device target and build its device part.

    The module is run through a shared ("mixed") pass pipeline and then
    filtered twice: once keeping only the device-kernel-launch functions and
    once keeping everything else, each followed by its own lowering passes.

    Parameters
    ----------
    input_mod : IRModule
        The schedule to be built.

    target : str or :any:`tvm.target.Target`
        The target and option of the compilation.

    target_host : str or :any:`tvm.target.Target`
        The host compilation target.

    Returns
    -------
    fhost : IRModule
        The host IRModule.

    mdev : tvm.module
        A module that contains device code, or None when no device
        functions remain after filtering.
    """
    target, target_host = Target.check_and_update_host_consist(target, target_host)
    device_type = ndarray.device(target.kind.name, 0).device_type

    def _is_device_kernel(func):
        # True for functions whose calling convention is a device kernel launch.
        return (
            "calling_conv" in func.attrs
            and func.attrs["calling_conv"].value == CallingConv.DEVICE_KERNEL_LAUNCH
        )

    # Tag every function with the compilation target before anything else runs.
    tag_target = tvm.tir.transform.Apply(lambda f: f.with_attr("target", target))
    mod_mixed = tag_target(input_mod)

    # Passes applied to the whole module before the host/device split.
    mixed_passes = [tvm.tir.transform.VerifyMemory()]
    if len(mod_mixed.functions) == 1:
        mixed_passes.append(
            tvm.tir.transform.Apply(lambda f: f.with_attr("tir.is_entry_func", True))
        )
    if PassContext.current().config.get("tir.detect_global_barrier", False):
        mixed_passes.append(tvm.tir.transform.ThreadSync("global"))
    mixed_passes.extend(
        [
            tvm.tir.transform.ThreadSync("shared"),
            tvm.tir.transform.ThreadSync("warp"),
            tvm.tir.transform.InferFragment(),
            tvm.tir.transform.LowerThreadAllreduce(),
            tvm.tir.transform.MakePackedAPI(),
            tvm.tir.transform.SplitHostDevice(),
        ]
    )
    mod_mixed = tvm.transform.Sequential(mixed_passes)(mod_mixed)

    # device optimizations
    device_pipeline = tvm.transform.Sequential(
        [
            tvm.tir.transform.Filter(_is_device_kernel),
            tvm.tir.transform.LowerWarpMemory(),
            tvm.tir.transform.Simplify(),
            tvm.tir.transform.LowerDeviceStorageAccessInfo(),
            tvm.tir.transform.LowerCustomDatatypes(),
            tvm.tir.transform.LowerIntrin(),
        ]
    )
    mod_dev = device_pipeline(mod_mixed)

    # host optimizations
    host_pipeline = tvm.transform.Sequential(
        [
            tvm.tir.transform.Filter(lambda f: not _is_device_kernel(f)),
            tvm.tir.transform.Apply(lambda f: f.with_attr("target", target_host)),
            tvm.tir.transform.LowerTVMBuiltin(),
            tvm.tir.transform.LowerDeviceStorageAccessInfo(),
            tvm.tir.transform.LowerCustomDatatypes(),
            tvm.tir.transform.LowerIntrin(),
            tvm.tir.transform.CombineContextCall(),
        ]
    )
    mod_host = host_pipeline(mod_mixed)

    num_dev_funcs = len(mod_dev.functions)
    if device_type == ndarray.cpu(0).device_type and target_host == target:
        assert len(mod_dev.functions) == 0
    if "gpu" in target.keys and len(mod_dev.functions) == 0:
        warnings.warn(
            "Specified target %s, but cannot find device code, did you do bind?" % target
        )

    if num_dev_funcs != 0:
        rt_mod_dev = codegen.build_module(mod_dev, target)
    else:
        rt_mod_dev = None
    return mod_host, rt_mod_dev

Here the MakePackedAPI pass generates the function with the BackendPackedCFunc signature.

/*!
 * \brief Rewrite a PrimFunc so that it is called through the packed-argument API.
 *
 * When arguments are packed, the rewritten parameter list starts with
 * (args, arg_type_ids, num_args), is followed by any parameters that stay
 * unpacked, and ends with (out_ret_value, out_ret_tcode, resource_handle).
 * The body is wrapped with the value loads, type-code checks and binder
 * assertions built here, the buffer_map is cleared, and the return type is
 * set to int32.
 *
 * \param func The function to rewrite; it must carry the global_symbol and
 *             target attributes (both are checked).
 * \param num_unpacked_args Number of parameters that are passed directly
 *             instead of through the packed arrays. -1 is treated as 0 and
 *             forces packing.
 * \return The rewritten function.
 */
PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) {
  auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
  ICHECK(global_symbol) << "MakePackedAPI: Expect PrimFunc to have the global_symbol attribute";

  auto target = func->GetAttr<Target>(tvm::attr::kTarget);
  ICHECK(target.defined()) << "MakePackedAPI: Require the target attribute";
  int target_device_type = target.value()->kind->device_type;

  std::string name_hint = global_symbol.value();

  // All in-place modifications below go through this pointer.
  auto* func_ptr = func.CopyOnWrite();
  const Stmt nop = Evaluate(0);
  int num_args = static_cast<int>(func_ptr->params.size());
  ICHECK_LE(num_unpacked_args, num_args);
  // Pack when explicitly requested (-1) or when some parameters are not
  // covered by num_unpacked_args.
  bool pack_args = (num_unpacked_args == -1) || (num_args > num_unpacked_args);
  if (num_unpacked_args == -1) {
    // reset to zero
    num_unpacked_args = 0;
  }
  ICHECK_GE(num_unpacked_args, 0);
  int num_packed_args = num_args - num_unpacked_args;
  // Data field definitions
  // The packed fields
  Var v_packed_args("args", DataType::Handle());
  Var v_packed_arg_type_ids("arg_type_ids", DataType::Handle());
  Var v_num_packed_args("num_args", DataType::Int(32));
  Var v_out_ret_value("out_ret_value", DataType::Handle());
  Var v_out_ret_tcode("out_ret_tcode", DataType::Handle());
  Var v_resource_handle("resource_handle", DataType::Handle());
  // The arguments of the function.
  Array<Var> args;
  // The device context
  Var device_id("dev_id");
  Integer device_type(target_device_type);
  // seq_init gives sequence of initialization
  // seq_check gives sequence of later checks after init
  std::vector<Stmt> seq_init, seq_check;
  std::unordered_map<const VarNode*, PrimExpr> vmap;
  ArgBinder binder(&vmap);
  // ---------------------------
  // local function definitions
  // load i-th argument as type t
  auto f_arg_value = [&](DataType t, int i) {
    Array<PrimExpr> call_args{v_packed_args, IntImm(DataType::Int(32), i),
                              IntImm(DataType::Int(32), builtin::kTVMValueContent)};
    // load 64 bit version
    DataType api_type = APIType(t);
    PrimExpr res = Call(api_type, builtin::tvm_struct_get(), call_args);
    // cast to the target version.
    if (api_type != t) {
      res = Cast(t, res);
    }
    return res;
  };
  // ---------------------------
  // start of logics
  // add signature for packed arguments.
  if (pack_args) {
    args.push_back(v_packed_args);
    args.push_back(v_packed_arg_type_ids);
    args.push_back(v_num_packed_args);
    std::ostringstream os;

    os << name_hint << ": num_args should be " << num_packed_args;
    seq_init.emplace_back(MakeAssertEQ(v_num_packed_args, num_packed_args, os.str()));
  }

  // Need to re-declare vars, in case some arguments also appears in the buffer.
  std::vector<std::pair<Var, Var> > var_def;
  std::vector<std::pair<Var, Buffer> > buffer_def;

  for (int i = 0; i < static_cast<int>(func_ptr->params.size()); ++i) {
    Var param = func_ptr->params[i];
    Var v_arg = Var("arg" + std::to_string(i), param->dtype);

    // Parameters that appear in buffer_map are bound as buffers later; the
    // others are bound as plain variables.
    auto it = func_ptr->buffer_map.find(param);
    if (it != func_ptr->buffer_map.end()) {
      buffer_def.emplace_back(v_arg, (*it).second);
    } else {
      var_def.emplace_back(v_arg, param);
    }
    // The first num_packed_args parameters are read from the packed arrays;
    // the remaining ones stay as direct parameters.
    if (i < num_packed_args) {
      // Value loads
      seq_init.emplace_back(LetStmt(v_arg, f_arg_value(v_arg.dtype(), i), nop));
      // type code checks
      Var tcode(v_arg->name_hint + ".code", DataType::Int(32));
      seq_init.emplace_back(LetStmt(tcode,
                                    Load(DataType::Int(32), v_packed_arg_type_ids,
                                         IntImm(DataType::Int(32), i), const_true(1)),
                                    nop));
      DataType t = v_arg.dtype();
      if (t.is_handle()) {
        std::ostringstream msg;
        msg << name_hint << ": Expect arg[" << i << "] to be pointer";
        seq_check.emplace_back(AssertStmt(tcode == kTVMOpaqueHandle || tcode == kTVMNDArrayHandle ||
                                              tcode == kTVMDLTensorHandle || tcode == kTVMNullptr,
                                          tvm::tir::StringImm(msg.str()), nop));
      } else if (t.is_int() || t.is_uint()) {
        std::ostringstream msg;
        msg << name_hint << ": Expect arg[" << i << "] to be int";
        seq_check.emplace_back(AssertStmt(tcode == kDLInt, tvm::tir::StringImm(msg.str()), nop));
      } else {
        // Only handle, int/uint and float parameters are supported here.
        ICHECK(t.is_float());
        std::ostringstream msg;
        msg << name_hint << ": Expect arg[" << i << "] to be float";
        seq_check.emplace_back(AssertStmt(tcode == kDLFloat, tvm::tir::StringImm(msg.str()), nop));
      }
    } else {
      args.push_back(v_arg);
    }
  }

  // allow return value if the function is packed.
  if (pack_args) {
    args.push_back(v_out_ret_value);
    args.push_back(v_out_ret_tcode);
    args.push_back(v_resource_handle);
  }

  // Six packed-API parameters: three pushed before the loop, three after it.
  size_t expected_nargs = num_unpacked_args + (pack_args ? 6 : 0);
  ICHECK_EQ(args.size(), expected_nargs);

  // Arg definitions are defined before buffer binding to avoid the use before
  // def errors.
  //
  // For example, for auto broadcasting, checks are required to guarantee that
  // either 0 or the original stride will be correctly used. Checks here have
  // to use the args that may have no let binding yet. Therefore, hoisting let
  // binding for args before buffer declaration is needed.
  for (const auto& kv : var_def) {
    binder.Bind(kv.second, kv.first, kv.first->name_hint, true);
  }

  for (const auto& kv : buffer_def) {
    binder.BindDLTensor(kv.second, device_type, device_id, kv.first, kv.first->name_hint);
  }

  // Mark the function with the packed calling convention when no parameter
  // is left unpacked.
  if (num_unpacked_args == 0) {
    func = WithAttr(std::move(func), tvm::attr::kCallingConv, Integer(CallingConv::kCPackedFunc));
  }

  Stmt body = RewriteReturn(func_ptr->body, v_out_ret_value, v_out_ret_tcode);
  body = AttrStmt(make_zero(DataType::Int(32)), attr::compute_scope,
                  StringImm(name_hint + "_compute_"), body);
  // Set device context
  // Only done when device_id ended up in vmap, the map handed to the binder.
  if (vmap.count(device_id.get())) {
    PrimExpr node = StringImm("default");
    seq_check.push_back(AttrStmt(node, attr::device_id, device_id, nop));
    seq_check.push_back(AttrStmt(node, attr::device_type, device_type, nop));

    if (runtime::DeviceAPI::NeedSetDevice(target_device_type)) {
      Stmt set_device =
          Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(),
                        {StringImm(runtime::symbol::tvm_set_device), device_type, device_id}));
      body = SeqStmt({set_device, body});
    }
  }
  // Nest order: value loads, binder initialization, type checks, binder
  // asserts, then the (possibly device-setting) body.
  func_ptr->body = MergeNest({seq_init, binder.init_nest(), seq_check, binder.asserts()}, body);
  func_ptr->params = args;

  // Every variable used by the new body must be reachable from the new
  // parameter list; otherwise abort with the list of offenders.
  Array<Var> undefined = UndefinedVars(func_ptr->body, func_ptr->params);
  if (undefined.size() != 0) {
    std::ostringstream os;
    for (Var v : undefined) {
      os << " \'" << v->name_hint << "\' ";
    }
    os << " is not bound to any variables";
    LOG(FATAL) << "Not all Vars are passed in api_args: " << os.str();
  }

  func_ptr->buffer_map = Map<Var, Buffer>();
  func_ptr->checked_type_ = func_ptr->func_type_annotation();
  func_ptr->ret_type = PrimType(DataType::Int(32));

  // return the function.
  // `func` is an rvalue-reference parameter, so the explicit move is needed
  // to avoid a copy under C++17 rules.
  return std::move(func);
}