Skip to content

Commit c234f87

Browse files
arshajii and inumanag authored
GPU codegen fixes (#746)
* GPU codegen fixes
* Fix parseCode() API
* Fix parseCode
* Import gpu on @par(gpu)
* Mark GPU functions as nothrow

Co-authored-by: Ibrahim Numanagić <ibrahimpasa@gmail.com>
1 parent db10a87 commit c234f87

8 files changed

Lines changed: 47 additions & 17 deletions

File tree

codon/cir/llvm/gpu.cpp

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -187,9 +187,12 @@ void linkLibdevice(llvm::Module *M, const std::string &path) {
187187
seqassertn(!fail, "linking libdevice failed");
188188
}
189189

190-
llvm::Function *copyPrototype(llvm::Function *F, const std::string &name) {
190+
llvm::Function *copyPrototype(llvm::Function *F, const std::string &name,
191+
bool external = false) {
191192
auto *M = F->getParent();
192-
return llvm::Function::Create(F->getFunctionType(), llvm::GlobalValue::PrivateLinkage,
193+
return llvm::Function::Create(F->getFunctionType(),
194+
external ? llvm::GlobalValue::ExternalLinkage
195+
: llvm::GlobalValue::PrivateLinkage,
193196
name.empty() ? F->getName() : name, *M);
194197
}
195198

@@ -651,7 +654,7 @@ void remapFunctions(llvm::Module *M) {
651654
} else {
652655
G = M->getFunction(pair.second);
653656
if (!G)
654-
G = copyPrototype(F, pair.second);
657+
G = copyPrototype(F, pair.second, /*external=*/true);
655658
}
656659

657660
G->setWillReturn();
@@ -712,6 +715,12 @@ std::string moduleToPTX(llvm::Module *M, std::vector<llvm::GlobalValue *> &kerne
712715
llvm::codegen::getExplicitRelocModel(), llvm::codegen::getExplicitCodeModel(),
713716
llvm::CodeGenOptLevel::Aggressive));
714717

718+
// Remove personality functions
719+
for (auto &F : *M) {
720+
F.setDoesNotThrow();
721+
F.setPersonalityFn(nullptr);
722+
}
723+
715724
M->setDataLayout(machine->createDataLayout());
716725
auto keep = getRequiredGVs(kernels);
717726

codon/cir/transform/parallel/openmp.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1560,7 +1560,7 @@ void OpenMPPass::handle(ImperativeForFlow *v) {
15601560

15611561
if (sched->gpu) {
15621562
std::unordered_set<id_t> kernels;
1563-
const std::string gpuAttr = ast::getMangledFunc("std.internal.gpu", "kernel");
1563+
const std::string gpuAttr = ast::getMangledFunc(gpuModule, "kernel");
15641564
for (auto *var : *M) {
15651565
if (auto *func = cast<BodiedFunc>(var)) {
15661566
if (util::hasAttribute(func, gpuAttr)) {

codon/parser/cache.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -208,14 +208,18 @@ std::vector<size_t> Cache::getChildRealizationIds(types::ClassType *type) {
208208
return childIds;
209209
}
210210

211-
void Cache::parseCode(const std::string &code) {
211+
std::vector<ir::SeriesFlow *> Cache::parseCode(const std::string &code) {
212212
auto nodeOrErr = ast::parseCode(this, "<internal>", code, /*startLine=*/0);
213-
if (nodeOrErr)
213+
if (!nodeOrErr)
214214
throw exc::ParserException(nodeOrErr.takeError());
215215
auto sctx = imports[MAIN_IMPORT].ctx;
216216
auto node = ast::TypecheckVisitor::apply(sctx, *nodeOrErr);
217+
auto old = codegenCtx->series;
218+
codegenCtx->series.clear();
217219
ast::TranslateVisitor(codegenCtx).initializeGlobals();
218220
ast::TranslateVisitor(codegenCtx).translateStmts(node);
221+
std::swap(old, codegenCtx->series);
222+
return old;
219223
}
220224

221225
std::vector<std::shared_ptr<types::ClassType>>

codon/parser/cache.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,7 @@ struct Cache {
331331
std::vector<size_t> getBaseRealizationIds(types::ClassType *type);
332332
std::vector<size_t> getChildRealizationIds(types::ClassType *type);
333333

334-
void parseCode(const std::string &code);
334+
std::vector<ir::SeriesFlow *> parseCode(const std::string &code);
335335

336336
static std::vector<std::shared_ptr<types::ClassType>>
337337
mergeC3(std::vector<std::vector<types::TypePtr>> &);

codon/parser/visitors/typecheck/loops.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,20 @@ void TypecheckVisitor::visit(WhileStmt *stmt) {
102102
void TypecheckVisitor::visit(ForStmt *stmt) {
103103
stmt->decorator = transformForDecorator(stmt->getDecorator());
104104

105+
if (auto fc = cast<CallExpr>(stmt->getDecorator())) {
106+
if (auto fi = cast<IdExpr>(fc->getExpr());
107+
fi && fi->getType()->getFunc() &&
108+
fi->getType()->getFunc()->getFuncName() ==
109+
getMangledFunc("std.openmp", "for_par")) {
110+
if (auto n = extractFuncGeneric(fi->getType(), 3)->getBoolStatic();
111+
n && n->value) {
112+
prependStmts->push_back(
113+
transform(N<ImportStmt>(N<IdExpr>("gpu"), nullptr, std::vector<Param>{},
114+
nullptr, getTemporaryVar("_"))));
115+
}
116+
}
117+
}
118+
105119
std::string breakVar;
106120
// Needs in-advance transformation to prevent name clashes with the iterator variable
107121
stmt->getIter()->setAttribute(

stdlib/internal/gpu.codon

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,8 +138,10 @@ def nvptx_load_module():
138138
module = CUmodule()
139139
# NOTE: 2nd argument to cuModuleLoadData() will
140140
# be replaced by Codon's GPU pass
141-
cuda_check(cuModuleLoadData(__ptr__(module), __codon_ptx__()))
142-
modules.append(module)
141+
ptx = __codon_ptx__()
142+
if ptx:
143+
cuda_check(cuModuleLoadData(__ptr__(module), ptx))
144+
modules.append(module)
143145

144146

145147
def cuda_init(debug: bool = False):

stdlib/internal/types/str.codon

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -144,16 +144,17 @@ class str:
144144
n += s.len
145145
return str(p, n)
146146
else:
147-
total = 0
148-
for i in args:
149-
if not isinstance(i, str):
147+
for i in static.range(static.len(args)):
148+
s = args[i]
149+
if not isinstance(s, str):
150150
compile_error("not a string")
151-
total += i.len
151+
total += s.len
152152
p = cobj(total)
153153
n = 0
154-
for i in args:
155-
str.memcpy(p + n, i.ptr, i.len)
156-
n += i.len
154+
for i in static.range(static.len(args)):
155+
s = args[i]
156+
str.memcpy(p + n, s.ptr, s.len)
157+
n += s.len
157158
return str(p, total)
158159

159160
def __prefix_b__(s: str, N: Literal[int]):

stdlib/numpy/ndarray.codon

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1080,7 +1080,7 @@ class ndarray[dtype, ndim: Literal[int]]:
10801080

10811081
def _check_order(order: str):
10821082
if order not in ('C', 'F', 'A', 'K'):
1083-
raise ValueError(f"order must be one of 'C', 'F', 'A', or 'K' (got {repr(order)})")
1083+
raise ValueError(f"order must be one of 'C', 'F', 'A', or 'K' (got '{order}')")
10841084

10851085
def astype(self, dtype: type, order: str = 'K', copy: bool = True):
10861086
ndarray._check_order(order)

0 commit comments

Comments (0)