WASM Code Generation Debug Note
Analysis
TVM Build
This Python interface checks the validity of the input and invokes codegen.build_module
codegen.build_module
def build_module(mod, target):
    """Compile an IRModule into a runtime Module for the given target.

    Parameters
    ----------
    mod : tvm.IRModule
        The IR module to build.

    target : str
        The target module type. A string is wrapped in ``Target`` first;
        any other value is forwarded unchanged.

    Returns
    -------
    module : runtime.Module
        The corresponding module, as returned by ``_ffi_api.Build``.
    """
    if isinstance(target, str):
        target = Target(target)
    return _ffi_api.Build(mod, target)
FFI
The `ffi` layer is a rather mysterious part of TVM. The `_ffi_api` module is defined as:
"""FFI APIs for tvm.target"""
import tvm._ffi
# Attach the registered global functions whose names start with "target"
# to this module as attributes (see _init_api / _init_api_prefix).
tvm._ffi._init_api("target", __name__)
Then:
def _init_api(namespace, target_module_name=None):
    """Expose the global functions of a registry namespace on a module.

    namespace : str
        The namespace of the source registry. A leading "tvm." is dropped
        before the namespace is used as the lookup prefix.

    target_module_name : str
        The target module name if different from namespace; when it is
        empty or None, ``namespace`` is used instead.
    """
    module_name = target_module_name or namespace
    prefix = namespace[4:] if namespace.startswith("tvm.") else namespace
    _init_api_prefix(module_name, prefix)
def _init_api_prefix(module_name, prefix):
    """Attach every global function found under ``prefix`` to a loaded module.

    module_name : str
        Key into ``sys.modules`` of the module that receives the functions.

    prefix : str
        Only global function names starting with this prefix are considered.
        The prefix plus one separator character is cut off to form the
        attribute name; names that still contain a "." afterwards are skipped.
    """
    target_module = sys.modules[module_name]
    # Number of leading characters to drop: the prefix and the separator after it.
    skip = len(prefix) + 1

    for full_name in list_global_func_names():
        if not full_name.startswith(prefix):
            continue

        short_name = full_name[skip:]
        if "." in short_name:
            # Belongs to a nested namespace; not exported on this module.
            continue

        wrapped = _get_api(get_global_func(full_name))
        wrapped.__name__ = short_name
        wrapped.__doc__ = "TVM PackedFunc %s. " % short_name
        setattr(target_module, wrapped.__name__, wrapped)
It seems that the FFI layer in TVM automatically collects every registered global function whose name starts with the given prefix and attaches it to the Python module.
So I then looked at the related C++ module.
src/target/codegen.cc
There is a `Build` function, which is also registered as a global function:
// Build an IRModule into a runtime::Module by dispatching to the builder
// registered for the target kind.
//
// mod:    the module to build; it is first run through SkipAssert when the
//         "tir.disable_assert" option of the current pass context is set.
// target: the target whose kind name selects the builder.
runtime::Module Build(IRModule mod, Target target) {
  const auto disable_assert =
      transform::PassContext::Current()->GetConfig<Bool>("tir.disable_assert", Bool(false));
  if (disable_assert.value()) {
    mod = tir::transform::SkipAssert()(mod);
  }

  // Look up the builder registered under "target.build.<kind name>".
  std::string build_f_name = "target.build." + target->kind->name;
  const PackedFunc* bf = runtime::Registry::Get(build_f_name);
  ICHECK(bf != nullptr) << build_f_name << " is not enabled";

  return (*bf)(mod, target);
}
// Some code here
// Register Build under the global name "target.Build"; this is the name the
// Python side reaches as _ffi_api.Build.
TVM_REGISTER_GLOBAL("target.Build").set_body_typed(Build);
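In my case, this function indirectly calls `target.build.llvm`: the name is assembled at run time as `"target.build." + target->kind->name` and looked up in the registry.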
src/target/llvm/llvm_module.cc
Just like the previous one, the function has been registered:
// Builder registered as "target.build.llvm": wraps a freshly initialized
// LLVMModuleNode in a runtime::Module.
TVM_REGISTER_GLOBAL("target.build.llvm")
    .set_body_typed([](IRModule mod, Target target) -> runtime::Module {
      auto node = make_object<LLVMModuleNode>();
      // Init receives the IR module and the target.
      node->Init(mod, target);
      return runtime::Module(node);
    });
This function invokes `LLVMModuleNode::Init`.
At the beginning, it does some parameter checks.
It then invokes `CodeGenLLVM::Init`, which initializes the code generator for LLVM.
The core functions for generating functions in LLVM are:
// Emit f through AddFunctionInternal with ret_void disabled, i.e. with a
// non-void return type.
void CodeGenLLVM::AddFunction(const PrimFunc& f) {
  const bool ret_void = false;
  this->AddFunctionInternal(f, ret_void);
}
void CodeGenLLVM::InitFuncState() {
var_map_.clear();
alias_var_set_.clear();
alloc_storage_info_.clear();
volatile_buf_.clear();
analyzer_.reset(new arith::Analyzer());
}
// Emit the LLVM function for a PrimFunc into module_.
//
// Steps: reset the per-function state, declare an externally linked function
// named after the PrimFunc's global_symbol attribute, bind every parameter to
// its LLVM argument, generate the body, attach parameter/function attributes
// and finally emit the return instruction.
//
// f:        the PrimFunc to compile. Its buffer_map must be empty and it must
//           carry the global_symbol attribute; a function with the same name
//           must not already exist in the module.
// ret_void: when true the function type returns t_void_ and ends with
//           `ret void`; otherwise it returns t_int_ and ends with
//           `ret ConstInt32(0)`.
void CodeGenLLVM::AddFunctionInternal(const PrimFunc& f, bool ret_void) {
  this->InitFuncState();

  ICHECK_EQ(f->buffer_map.size(), 0U)
      << "Cannot codegen function with buffer_map, please lower them first";

  std::vector<llvm::Type*> param_types;
  is_restricted_ = f->HasNonzeroAttr(tir::attr::kNoAlias);
  for (Var param : f->params) {
    param_types.push_back(GetLLVMType(param));
    // Handle parameters are recorded as aliasing only when the function is
    // not marked no-alias.
    if (!is_restricted_ && param.dtype().is_handle()) {
      alias_var_set_.insert(param.get());
    }
  }
  // TODO(tvm-team):
  // Update the function type to respect the ret_type field of f.
  // Once we allow more flexibility in the PrimFunc.
  llvm::FunctionType* ftype =
      llvm::FunctionType::get(ret_void ? t_void_ : t_int_, param_types, false);

  auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
  ICHECK(global_symbol.defined())
      << "CodeGenLLVM: Expect PrimFunc to have the global_symbol attribute";
  ICHECK(module_->getFunction(static_cast<std::string>(global_symbol.value())) == nullptr)
      << "Function " << global_symbol << " already exist in module";

  function_ = llvm::Function::Create(ftype, llvm::Function::ExternalLinkage,
                                     global_symbol.value().operator std::string(), module_.get());
  function_->setCallingConv(llvm::CallingConv::C);
  function_->setDLLStorageClass(llvm::GlobalValue::DLLStorageClassTypes::DLLExportStorageClass);

  // set var map and align information
  auto arg_it = function_->arg_begin();
  for (size_t i = 0; i < f->params.size(); ++i, ++arg_it) {
    llvm::Argument* v = &(*arg_it);
    const Var& var = f->params[i];
    var_map_[var.get()] = v;
    if (is_restricted_) {
      if (var.dtype().is_handle() && !alias_var_set_.count(var.get())) {
        // set non alias.
#if TVM_LLVM_VERSION >= 50
        function_->addParamAttr(i, llvm::Attribute::NoAlias);
#else
        function_->setDoesNotAlias(i + 1);
#endif
      }
    }
  }

  llvm::BasicBlock* entry = llvm::BasicBlock::Create(*ctx_, "entry", function_);
  builder_->SetInsertPoint(entry);
  this->VisitStmt(f->body);

  // Add alignment attribute if needed.
#if TVM_LLVM_VERSION >= 50
  for (size_t i = 0; i < f->params.size(); ++i) {
    const Var& var = f->params[i];
    // Named info_it (not f) so the PrimFunc parameter `f`, which the loop
    // condition still reads, is not shadowed inside the loop body.
    auto info_it = alloc_storage_info_.find(var.get());
    if (info_it != alloc_storage_info_.end()) {
      unsigned align = info_it->second.alignment;
      if (align > 1) {
        auto attr = llvm::Attribute::get(*ctx_, llvm::Attribute::Alignment, align);
        function_->addParamAttr(i, attr);
      }
    }
  }
#endif

  // Forward the target machine's feature string, when it has one.
  llvm::StringRef fs = target_machine_->getTargetFeatureString();
  if (!fs.empty()) {
    function_->addFnAttr("target-features", fs);
  }

  if (ret_void) {
    builder_->CreateRetVoid();
  } else {
    builder_->CreateRet(ConstInt32(0));
  }
}
However, there is no function here that has the same signature as `BackendPackedCFunc`.
TIR Level Transformation
After searching for a while, I finally found the place where the `BackendPackedCFunc` signature is generated. It is invoked from the Python script through `mod_host, mdev = _build_for_device(input_mod, tar, target_host)`:
def _build_for_device(input_mod, target, target_host):
    """Lower a module for one device and split it into host and device parts.

    Parameters
    ----------
    input_mod : IRModule
        The schedule to be built.

    target : str or :any:`tvm.target.Target`
        The target and option of the compilation.

    target_host : str or :any:`tvm.target.Target`
        The host compilation target.

    Returns
    -------
    fhost : IRModule
        The host IRModule.

    mdev : tvm.module
        A module that contains device code, or None when the lowered module
        has no device functions.
    """
    target, target_host = Target.check_and_update_host_consist(target, target_host)
    device_type = ndarray.device(target.kind.name, 0).device_type

    def is_device_kernel(func):
        # A function counts as device code when its calling convention is
        # DEVICE_KERNEL_LAUNCH.
        return (
            "calling_conv" in func.attrs
            and func.attrs["calling_conv"].value == CallingConv.DEVICE_KERNEL_LAUNCH
        )

    # Tag every function with the device target before running the mixed passes.
    mixed_mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", target))(input_mod)

    mixed_passes = [tvm.tir.transform.VerifyMemory()]
    if len(mixed_mod.functions) == 1:
        # A lone function is marked as the entry function.
        mixed_passes.append(
            tvm.tir.transform.Apply(lambda f: f.with_attr("tir.is_entry_func", True))
        )
    if PassContext.current().config.get("tir.detect_global_barrier", False):
        mixed_passes.append(tvm.tir.transform.ThreadSync("global"))
    mixed_passes.extend(
        [
            tvm.tir.transform.ThreadSync("shared"),
            tvm.tir.transform.ThreadSync("warp"),
            tvm.tir.transform.InferFragment(),
            tvm.tir.transform.LowerThreadAllreduce(),
            tvm.tir.transform.MakePackedAPI(),
            tvm.tir.transform.SplitHostDevice(),
        ]
    )
    mixed_mod = tvm.transform.Sequential(mixed_passes)(mixed_mod)

    # device optimizations
    device_pipeline = tvm.transform.Sequential(
        [
            tvm.tir.transform.Filter(is_device_kernel),
            tvm.tir.transform.LowerWarpMemory(),
            tvm.tir.transform.Simplify(),
            tvm.tir.transform.LowerDeviceStorageAccessInfo(),
            tvm.tir.transform.LowerCustomDatatypes(),
            tvm.tir.transform.LowerIntrin(),
        ]
    )
    mod_dev = device_pipeline(mixed_mod)

    # host optimizations
    host_pipeline = tvm.transform.Sequential(
        [
            tvm.tir.transform.Filter(lambda f: not is_device_kernel(f)),
            tvm.tir.transform.Apply(lambda f: f.with_attr("target", target_host)),
            tvm.tir.transform.LowerTVMBuiltin(),
            tvm.tir.transform.LowerDeviceStorageAccessInfo(),
            tvm.tir.transform.LowerCustomDatatypes(),
            tvm.tir.transform.LowerIntrin(),
            tvm.tir.transform.CombineContextCall(),
        ]
    )
    mod_host = host_pipeline(mixed_mod)

    has_device_code = len(mod_dev.functions) != 0

    if device_type == ndarray.cpu(0).device_type and target_host == target:
        # Building for the CPU host itself must not leave any device function behind.
        assert not has_device_code
    if "gpu" in target.keys and not has_device_code:
        warnings.warn(
            "Specified target %s, but cannot find device code, did you do " "bind?" % target
        )

    if has_device_code:
        rt_mod_dev = codegen.build_module(mod_dev, target)
    else:
        rt_mod_dev = None
    return mod_host, rt_mod_dev
Here the `MakePackedAPI` pass is what generates the function with the `BackendPackedCFunc` signature.
// Rewrite a PrimFunc so that it takes its arguments in packed form.
//
// When arguments are packed, the new parameter list is
//   (args, arg_type_ids, num_args, <unpacked args...>,
//    out_ret_value, out_ret_tcode, resource_handle)
// and the body is wrapped with the statements that load each packed argument,
// check its type code and bind it to the original parameter or buffer.
//
// func:              the function to rewrite; it must carry the global_symbol
//                    and target attributes.
// num_unpacked_args: number of trailing parameters that stay ordinary
//                    parameters; -1 means "pack everything" and is treated as 0.
//
// Returns the rewritten function, with an empty buffer_map and an int32
// ret_type.
PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) {
auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
ICHECK(global_symbol) << "MakePackedAPI: Expect PrimFunc to have the global_symbol attribute";
auto target = func->GetAttr<Target>(tvm::attr::kTarget);
ICHECK(target.defined()) << "MakePackedAPI: Require the target attribute";
int target_device_type = target.value()->kind->device_type;
std::string name_hint = global_symbol.value();
// func_ptr is the node that the rest of this function mutates.
auto* func_ptr = func.CopyOnWrite();
const Stmt nop = Evaluate(0);
int num_args = static_cast<int>(func_ptr->params.size());
ICHECK_LE(num_unpacked_args, num_args);
// Arguments are packed when -1 was requested or when there are more
// parameters than unpacked slots.
bool pack_args = (num_unpacked_args == -1) || (num_args > num_unpacked_args);
if (num_unpacked_args == -1) {
// reset to zero
num_unpacked_args = 0;
}
ICHECK_GE(num_unpacked_args, 0);
int num_packed_args = num_args - num_unpacked_args;
// Data field definitions
// The packed fields
Var v_packed_args("args", DataType::Handle());
Var v_packed_arg_type_ids("arg_type_ids", DataType::Handle());
Var v_num_packed_args("num_args", DataType::Int(32));
Var v_out_ret_value("out_ret_value", DataType::Handle());
Var v_out_ret_tcode("out_ret_tcode", DataType::Handle());
Var v_resource_handle("resource_handle", DataType::Handle());
// The arguments of the function.
Array<Var> args;
// The device context
Var device_id("dev_id");
Integer device_type(target_device_type);
// seq_init gives sequence of initialization
// seq_check gives sequence of later checks after init
std::vector<Stmt> seq_init, seq_check;
std::unordered_map<const VarNode*, PrimExpr> vmap;
ArgBinder binder(&vmap);
// ---------------------------
// local function definitions
// load i-th argument as type t
auto f_arg_value = [&](DataType t, int i) {
Array<PrimExpr> call_args{v_packed_args, IntImm(DataType::Int(32), i),
IntImm(DataType::Int(32), builtin::kTVMValueContent)};
// load 64 bit version
DataType api_type = APIType(t);
PrimExpr res = Call(api_type, builtin::tvm_struct_get(), call_args);
// cast to the target version.
if (api_type != t) {
res = Cast(t, res);
}
return res;
};
// ---------------------------
// start of logic
// add signature for packed arguments.
if (pack_args) {
args.push_back(v_packed_args);
args.push_back(v_packed_arg_type_ids);
args.push_back(v_num_packed_args);
std::ostringstream os;
os << name_hint << ": num_args should be " << num_packed_args;
seq_init.emplace_back(MakeAssertEQ(v_num_packed_args, num_packed_args, os.str()));
}
// Need to re-declare vars, in case some arguments also appears in the buffer.
std::vector<std::pair<Var, Var> > var_def;
std::vector<std::pair<Var, Buffer> > buffer_def;
for (int i = 0; i < static_cast<int>(func_ptr->params.size()); ++i) {
Var param = func_ptr->params[i];
Var v_arg = Var("arg" + std::to_string(i), param->dtype);
// Parameters that appear in buffer_map are bound as buffers later on; all
// others are bound as plain variables.
auto it = func_ptr->buffer_map.find(param);
if (it != func_ptr->buffer_map.end()) {
buffer_def.emplace_back(v_arg, (*it).second);
} else {
var_def.emplace_back(v_arg, param);
}
// The first num_packed_args parameters are read out of the packed argument
// array; the remaining ones are passed through as ordinary parameters.
if (i < num_packed_args) {
// Value loads
seq_init.emplace_back(LetStmt(v_arg, f_arg_value(v_arg.dtype(), i), nop));
// type code checks
Var tcode(v_arg->name_hint + ".code", DataType::Int(32));
seq_init.emplace_back(LetStmt(tcode,
Load(DataType::Int(32), v_packed_arg_type_ids,
IntImm(DataType::Int(32), i), const_true(1)),
nop));
DataType t = v_arg.dtype();
if (t.is_handle()) {
std::ostringstream msg;
msg << name_hint << ": Expect arg[" << i << "] to be pointer";
seq_check.emplace_back(AssertStmt(tcode == kTVMOpaqueHandle || tcode == kTVMNDArrayHandle ||
tcode == kTVMDLTensorHandle || tcode == kTVMNullptr,
tvm::tir::StringImm(msg.str()), nop));
} else if (t.is_int() || t.is_uint()) {
std::ostringstream msg;
msg << name_hint << ": Expect arg[" << i << "] to be int";
seq_check.emplace_back(AssertStmt(tcode == kDLInt, tvm::tir::StringImm(msg.str()), nop));
} else {
ICHECK(t.is_float());
std::ostringstream msg;
msg << name_hint << ": Expect arg[" << i << "] to be float";
seq_check.emplace_back(AssertStmt(tcode == kDLFloat, tvm::tir::StringImm(msg.str()), nop));
}
} else {
args.push_back(v_arg);
}
}
// allow return value if the function is packed.
if (pack_args) {
args.push_back(v_out_ret_value);
args.push_back(v_out_ret_tcode);
args.push_back(v_resource_handle);
}
// 6 = the three leading packed parameters (args, arg_type_ids, num_args) plus
// the three trailing ones (out_ret_value, out_ret_tcode, resource_handle).
size_t expected_nargs = num_unpacked_args + (pack_args ? 6 : 0);
ICHECK_EQ(args.size(), expected_nargs);
// Arg definitions are defined before buffer binding to avoid the use before
// def errors.
//
// For example, for auto broadcasting, checks are required to guarantee that
// either 0 or the original stride will be correctly used. Checks here have
// to use the args that may have no let binding yet. Therefore, hoisting let
// binding for args before buffer declaration is needed.
for (const auto& kv : var_def) {
binder.Bind(kv.second, kv.first, kv.first->name_hint, true);
}
for (const auto& kv : buffer_def) {
binder.BindDLTensor(kv.second, device_type, device_id, kv.first, kv.first->name_hint);
}
// With no unpacked arguments, tag the function with the kCPackedFunc calling
// convention.
if (num_unpacked_args == 0) {
func = WithAttr(std::move(func), tvm::attr::kCallingConv, Integer(CallingConv::kCPackedFunc));
}
Stmt body = RewriteReturn(func_ptr->body, v_out_ret_value, v_out_ret_tcode);
body = AttrStmt(make_zero(DataType::Int(32)), attr::compute_scope,
StringImm(name_hint + "_compute_"), body);
// Set device context
// Only done when device_id has an entry in vmap (presumably added by
// BindDLTensor above -- TODO confirm).
if (vmap.count(device_id.get())) {
PrimExpr node = StringImm("default");
seq_check.push_back(AttrStmt(node, attr::device_id, device_id, nop));
seq_check.push_back(AttrStmt(node, attr::device_type, device_type, nop));
if (runtime::DeviceAPI::NeedSetDevice(target_device_type)) {
Stmt set_device =
Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(),
{StringImm(runtime::symbol::tvm_set_device), device_type, device_id}));
body = SeqStmt({set_device, body});
}
}
// Nesting order: argument loads, binder initialization, type-code checks,
// binder asserts, and innermost the body.
func_ptr->body = MergeNest({seq_init, binder.init_nest(), seq_check, binder.asserts()}, body);
func_ptr->params = args;
// Every variable used by the new body must be one of the new parameters or
// be defined inside the body; anything left over is reported as fatal.
Array<Var> undefined = UndefinedVars(func_ptr->body, func_ptr->params);
if (undefined.size() != 0) {
std::ostringstream os;
for (Var v : undefined) {
os << " \'" << v->name_hint << "\' ";
}
os << " is not bound to any variables";
LOG(FATAL) << "Not all Vars are passed in api_args: " << os.str();
}
// The buffers were bound above, so the rewritten function has no buffer_map.
func_ptr->buffer_map = Map<Var, Buffer>();
func_ptr->checked_type_ = func_ptr->func_type_annotation();
func_ptr->ret_type = PrimType(DataType::Int(32));
// return the function.
return std::move(func);
}