Skip to content

Commit 12188f8

Browse files
committed
Enforce tuples for __ptr__
1 parent 09508a0 commit 12188f8

4 files changed

Lines changed: 97 additions & 22 deletions

File tree

codon/compiler/error.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -419,7 +419,7 @@ template <class... TA> std::string Emsg(Error e, const TA &...args) {
419419
case Error::CALL_SUPER_PARENT:
420420
return fmt::format("no super methods found");
421421
case Error::CALL_PTR_VAR:
422-
return fmt::format("__ptr__() only takes identifiers as arguments");
422+
return fmt::format("__ptr__() only takes identifiers or tuple fields as arguments");
423423
case Error::EXPECTED_TUPLE:
424424
return fmt::format("expected tuple type");
425425
case Error::CALL_REALIZED_FN:

codon/parser/visitors/translate/translate.cpp

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -276,22 +276,28 @@ void TranslateVisitor::visit(CallExpr *expr) {
276276
auto ei = cast<IdExpr>(expr->getExpr());
277277
if (ei && ei->getValue() == getMangledFunc("std.internal.core", "__ptr__")) {
278278
auto head = expr->begin()->getExpr();
279+
ir::FlowInstr *pre = cast<ir::FlowInstr>(transform(head));
279280
while (auto sexp = cast<StmtExpr>(head))
280281
head = sexp->getExpr();
281282

282-
std::string member;
283-
if (auto id = cast<DotExpr>(head)) {
283+
std::vector<std::string> members;
284+
while (auto id = cast<DotExpr>(head)) {
285+
members.emplace_back(id->getMember());
284286
head = id->getExpr();
285-
member = id->getMember();
286287
}
288+
std::ranges::reverse(members);
287289
auto id = cast<IdExpr>(head);
288290
seqassert(id, "expected IdExpr, got {}", *((*expr)[0].value));
289291
auto key = id->getValue();
290292
auto val = ctx->find(key);
291293
seqassert(val && val->getVar(), "{} is not a variable", key);
292-
result = make<ir::PointerValue>(expr, val->getVar(),
293-
member.empty() ? std::vector<std::string>{}
294-
: std::vector<std::string>{member});
294+
295+
auto pv = make<ir::PointerValue>(expr, val->getVar(), members);
296+
if (pre) {
297+
pre->setValue(pv);
298+
} else {
299+
result = pv;
300+
}
295301
return;
296302
} else if (ei && ei->getValue() ==
297303
getMangledMethod("std.internal.core", "__array__", "__new__")) {

codon/parser/visitors/typecheck/special.cpp

Lines changed: 31 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -498,23 +498,40 @@ Expr *TypecheckVisitor::transformPtr(CallExpr *expr) {
498498
expr->begin()->value = transform(expr->begin()->getExpr());
499499

500500
auto head = getHeadExpr(expr->begin()->getExpr());
501-
auto id = cast<IdExpr>(head);
502-
std::string member;
503-
if (auto dot = cast<DotExpr>(head)) {
504-
if (((id = cast<IdExpr>(dot->getExpr())))) {
505-
member = dot->getMember();
501+
std::vector<std::string> members;
502+
for (bool last = true;; last = false) {
503+
auto t = extractClassType(head);
504+
if (!t)
505+
return nullptr;
506+
if (!t->isRecord())
507+
E(Error::CALL_PTR_VAR, expr->begin()->getExpr());
508+
509+
if (auto id = cast<IdExpr>(head)) {
510+
auto val = id ? ctx->find(id->getValue(), getTime()) : nullptr;
511+
if (!val || !val->isVar()) {
512+
E(Error::CALL_PTR_VAR, expr->begin()->getExpr());
513+
}
514+
break;
515+
} else if (auto dot = cast<DotExpr>(head)) {
516+
if (last && !t->isRecord()) {
517+
E(Error::CALL_PTR_VAR, expr->begin()->getExpr());
518+
} else if (!t->isRecord()) {
519+
auto tmp = getTemporaryVar("ptr");
520+
auto newDot = N<DotExpr>(N<IdExpr>(tmp), dot->getMember());
521+
std::ranges::reverse(members);
522+
for (auto &m : members)
523+
newDot = N<DotExpr>(newDot, m);
524+
return transform(N<StmtExpr>(
525+
N<AssignStmt>(N<IdExpr>(tmp), dot->getExpr()),
526+
N<CallExpr>(N<IdExpr>(getMangledFunc("std.internal.core", "__ptr__")),
527+
newDot)));
528+
}
529+
head = dot->getExpr();
506530
} else {
507-
auto tmp = getTemporaryVar("ptr");
508-
return transform(N<StmtExpr>(
509-
N<AssignStmt>(N<IdExpr>(tmp), dot->getExpr()),
510-
N<CallExpr>(N<IdExpr>(getMangledFunc("std.internal.core", "__ptr__")),
511-
N<DotExpr>(N<IdExpr>(tmp), dot->getMember()))));
531+
E(Error::CALL_PTR_VAR, expr->begin()->getExpr());
532+
break;
512533
}
513534
}
514-
auto val = id ? ctx->find(id->getValue(), getTime()) : nullptr;
515-
if (!val || !val->isVar()) {
516-
E(Error::CALL_PTR_VAR, expr->begin()->getExpr());
517-
}
518535

519536
unify(expr->getType(),
520537
instantiateType(getStdLibType("Ptr"), {expr->begin()->getExpr()->getType()}));

test/parser/typecheck/test_call.codon

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,13 +41,65 @@ def test_class():
4141
# test_class()
4242

4343

44+
@tuple
45+
class A:
46+
n: int
47+
48+
@tuple
49+
class B:
50+
a: A
51+
52+
@tuple
53+
class C:
54+
b: B
55+
56+
x = C(B(A(1)))
57+
p = __ptr__(x.b.a.n)
58+
p[0] = 55
59+
print(x) #: (b: (a: (n: 55)))
60+
61+
4462
#%% call_ptr_error,barebones
45-
__ptr__(1) #! __ptr__() only takes identifiers as arguments
63+
__ptr__(1) #! __ptr__() only takes identifiers or tuple fields as arguments
64+
65+
#%% call_ptr_error_2,barebones
66+
__ptr__([1]) #! __ptr__() only takes identifiers or tuple fields as arguments
4667

4768
#%% call_ptr_error_3,barebones
4869
v = 1
4970
__ptr__(v, 1) #! __ptr__() takes 1 arguments (2 given)
5071

72+
#%% call_ptr_error_4,barebones
73+
@tuple
74+
class A:
75+
n: int
76+
77+
class B:
78+
a: A
79+
80+
@tuple
81+
class C:
82+
b: B
83+
84+
x = C(B(A(1)))
85+
print(__ptr__(x.b.a.n)) #! __ptr__() only takes identifiers or tuple fields as arguments
86+
87+
#%% call_ptr_error_5,barebones
88+
@tuple
89+
class A:
90+
n: int
91+
92+
class B:
93+
a: A
94+
95+
@tuple
96+
class C:
97+
b: B
98+
99+
x = C(B(A(1)))
100+
print(__ptr__(A(1).n)) #! __ptr__() only takes identifiers or tuple fields as arguments
101+
102+
51103
#%% call_array,barebones
52104
a = __array__[int](2)
53105
a[0] = a[1] = 5

0 commit comments

Comments
 (0)