Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions codon/parser/ast/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ namespace codon::ast {
void accept(VISITOR &visitor) override; \
std::string toString(int) const override; \
friend class TypecheckVisitor; \
friend class AutoDeduceMembersTypecheckVisitor; \
template <typename TE, typename TS> friend struct CallbackASTVisitor; \
friend struct ReplacingCallbackASTVisitor; \
inline decltype(auto) match_members() const { return std::tie(__VA_ARGS__); } \
Expand Down Expand Up @@ -51,6 +52,8 @@ struct Expr : public AcceptorExtend<Expr, ASTNode> {
void setDone() { done = true; }
Expr *getOrigExpr() const { return origExpr; }
void setOrigExpr(Expr *orig) { origExpr = orig; }
Expr *getTypeExpr() const { return typeExpr; }
void setTypeExpr(Expr *type) { typeExpr = type; }

static const char NodeId;
SERIALIZE(Expr, BASE(ASTNode), /*type,*/ done, origExpr);
Expand All @@ -69,6 +72,8 @@ struct Expr : public AcceptorExtend<Expr, ASTNode> {
bool done;
/// Original (pre-transformation) expression
Expr *origExpr;
/// the expression of type
Expr *typeExpr{nullptr};
};

/// Function signature parameter helper node (name: type = defaultValue).
Expand Down
1 change: 1 addition & 0 deletions codon/parser/ast/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ namespace codon::ast {
void accept(VISITOR &visitor) override; \
std::string toString(int) const override; \
friend class TypecheckVisitor; \
friend class AutoDeduceMembersTypecheckVisitor; \
template <typename TE, typename TS> friend struct CallbackASTVisitor; \
friend struct ReplacingCallbackASTVisitor; \
inline decltype(auto) match_members() const { return std::tie(__VA_ARGS__); } \
Expand Down
72 changes: 67 additions & 5 deletions codon/parser/visitors/typecheck/class.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,61 @@ std::vector<TypePtr> TypecheckVisitor::parseBaseClasses(
return asts;
}

/// Infer a concrete type for `self.<member>` from the assignments found in the
/// body of method `f` (the `__init__` used for member deduction).
/// The first parameter of `f` is taken as the receiver name.
/// @return the inferred type expression, or nullptr when `f` has no parameters
///         or no assignment to the member could be typed.
Expr *TypecheckVisitor::inferMemberType(std::string member, FunctionStmt *f) {
  if (!f->items.empty()) {
    // The deducer starts out knowing the method's parameters (and their annotations).
    const auto &receiver = f->items.front().name;
    AutoDeduceMembersTypecheckVisitor deducer(ctx, f->items);
    return inferMemberType(receiver, member, f->getSuite(), deducer);
  }
  return nullptr;
}
/// Recursively scan `stmt` for assignments to `<self>.<member>` and infer the
/// type of the assigned value.
///
/// Compound statements (suites, loops, `if`, `match`, `try`) are searched in
/// source order, and a later successful inference replaces an earlier one.
/// Assignments to plain identifiers are recorded in `v`, so later right-hand
/// sides that mention those locals can be typed as well.
///
/// @param self   name of the method's receiver parameter (e.g. `self`)
/// @param member name of the member whose type is wanted
/// @param stmt   statement to scan; may be nullptr (e.g. an absent `else` block)
/// @param v      expression type deducer holding the currently known variables
/// @return a freshly cloned type expression, or nullptr if none could be inferred
Expr *TypecheckVisitor::inferMemberType(std::string self, std::string member,
                                        Stmt *stmt,
                                        AutoDeduceMembersTypecheckVisitor &v) {
  if (!stmt)
    return nullptr;

  // Scan several child statements in order, keeping the last non-null result.
  auto scanAll = [&](std::initializer_list<Stmt *> children) -> Expr * {
    Expr *found = nullptr;
    for (auto *child : children)
      if (auto typ = inferMemberType(self, member, child, v))
        found = typ;
    return found;
  };

  if (auto suite = cast<SuiteStmt>(stmt)) {
    Expr *found = nullptr;
    for (auto *s : *suite)
      if (auto typ = inferMemberType(self, member, s, v))
        found = typ;
    return found;
  }
  if (auto whileLoop = cast<WhileStmt>(stmt))
    return scanAll({whileLoop->getSuite(), whileLoop->getElse()});
  if (auto forLoop = cast<ForStmt>(stmt))
    return scanAll({forLoop->getSuite(), forLoop->getElse()});
  if (auto ifStmt = cast<IfStmt>(stmt))
    return scanAll({ifStmt->getIf(), ifStmt->getElse()});
  if (auto matchStmt = cast<MatchStmt>(stmt)) {
    Expr *found = nullptr;
    for (auto &c : matchStmt->items)
      if (auto typ = inferMemberType(self, member, c.getSuite(), v))
        found = typ;
    return found;
  }
  if (auto tryStmt = cast<TryStmt>(stmt)) {
    // NOTE(review): `except` handler bodies are not scanned here.
    return scanAll({tryStmt->getSuite(), tryStmt->getElse(), tryStmt->getFinally()});
  }
  if (auto assignStmt = cast<AssignStmt>(stmt)) {
    // An AssignStmt may carry no right-hand side (presumably an annotation-only
    // declaration such as `x: int`); there is nothing to type in that case.
    if (!assignStmt->getRhs())
      return nullptr;
    // Type a copy so the deducer never annotates nodes of the real method body.
    auto rhs = clone(assignStmt->getRhs());
    rhs->accept(v);
    auto rhsType = rhs->getTypeExpr();
    if (auto lhs = cast<DotExpr>(assignStmt->getLhs())) {
      auto obj = cast<IdExpr>(lhs->getExpr());
      // The deduced type may alias method parameter annotations or nodes shared
      // between deduced expressions; hand out a private copy because callers
      // (autoDeduceMembers) insert it into the class's parameter list.
      if (rhsType && obj && obj->getValue() == self && lhs->getMember() == member)
        return clone(rhsType);
    } else if (auto lhs = cast<IdExpr>(assignStmt->getLhs())) {
      v.addVar(lhs->getValue(), rhsType);
    }
    return nullptr;
  }
  return nullptr;
}

/// Find the first __init__ with self parameter and use it to deduce class members.
/// Each deduced member will be treated as generic.
/// @example
Expand All @@ -493,24 +548,31 @@ std::vector<TypePtr> TypecheckVisitor::parseBaseClasses(
/// @return the transformed init and the pointer to the original function.
bool TypecheckVisitor::autoDeduceMembers(ClassStmt *stmt, std::vector<Param> &args) {
std::set<std::string> members;
std::unordered_map<std::string, Expr *> member2type;
for (const auto &sp : getClassMethods(stmt->suite))
if (auto f = cast<FunctionStmt>(sp)) {
if (f->name == "__init__")
if (const auto b =
f->getAttribute<ir::StringListAttribute>(Attr::ClassDeduce)) {
for (const auto &m : b->values)
for (const auto &m : b->values) {
members.insert(m);
member2type[m] = inferMemberType(m, f);
}
}
}
if (!members.empty()) {
// log("auto-deducing {}: {}", stmt->name, members);
if (auto aa = stmt->getAttribute<ir::StringListAttribute>(Attr::ClassMagic))
std::erase(aa->values, "init");
for (auto m : members) {
auto genericName = fmt::format("T_{}", m);
args.emplace_back(genericName, N<IdExpr>(TYPE_TYPE), N<IdExpr>("NoneType"),
Param::Generic);
args.emplace_back(m, N<IdExpr>(genericName));
if (auto typ = member2type[m]) {
args.emplace_back(m, typ);
} else {
auto genericName = fmt::format("T_{}", m);
args.emplace_back(genericName, N<IdExpr>(TYPE_TYPE), N<IdExpr>("NoneType"),
Param::Generic);
args.emplace_back(m, N<IdExpr>(genericName));
}
}
return true;
}
Expand Down
136 changes: 136 additions & 0 deletions codon/parser/visitors/typecheck/typecheck.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1924,4 +1924,140 @@ ir::PyFunction TypecheckVisitor::cythonizeFunction(const std::string &name) {
return {"", ""};
}

/// Boolean literals are typed as `bool`.
void AutoDeduceMembersTypecheckVisitor::visit(BoolExpr *exp) {
  Expr *boolType = N<IdExpr>("bool");
  exp->setTypeExpr(boolType);
}

/// Integer literals are typed as `int`.
void AutoDeduceMembersTypecheckVisitor::visit(IntExpr *exp) {
  Expr *intType = N<IdExpr>("int");
  exp->setTypeExpr(intType);
}

/// Floating-point literals are typed as `float`.
void AutoDeduceMembersTypecheckVisitor::visit(FloatExpr *exp) {
  Expr *floatType = N<IdExpr>("float");
  exp->setTypeExpr(floatType);
}

/// String literals are typed as `str`.
void AutoDeduceMembersTypecheckVisitor::visit(StringExpr *exp) {
  Expr *strType = N<IdExpr>("str");
  exp->setTypeExpr(strType);
}

/// Type an identifier from the most recent binding with the same name. The
/// binding is either one of the method's parameters (using its annotation) or a
/// local recorded through addVar(). Identifiers with no known binding are left
/// untyped.
void AutoDeduceMembersTypecheckVisitor::visit(IdExpr *exp) {
  const auto name = exp->getValue();
  // Search newest-first: addVar() appends on every assignment, so a reassigned
  // local (or a parameter rebound in the body) must resolve to its latest binding
  // rather than to the first one ever recorded.
  auto it = std::find_if(args.rbegin(), args.rend(),
                         [&name](const Param &arg) { return arg.name == name; });
  if (it != args.rend())
    exp->setTypeExpr(it->getType());
}

/// Type a tuple literal as `Tuple[T1, T2, ...]` from its element types.
/// If any element cannot be typed the tuple stays untyped, so a type expression
/// with a missing (null) argument is never produced.
void AutoDeduceMembersTypecheckVisitor::visit(TupleExpr *exp) {
  std::vector<Expr *> items;
  items.reserve(exp->items.size());
  for (auto e : exp->items) {
    e->accept(*this);
    auto typ = e->getTypeExpr();
    if (!typ)
      return;
    items.push_back(typ);
  }
  exp->setTypeExpr(N<IndexExpr>(N<IdExpr>("Tuple"), N<TupleExpr>(items)));
}

/// Type a list literal as `List[T]`, where T is the type of its first element.
/// Empty literals and literals whose first element cannot be typed stay untyped,
/// so `List[...]` is never built with a missing (null) argument.
void AutoDeduceMembersTypecheckVisitor::visit(ListExpr *exp) {
  if (exp->items.empty())
    return;
  auto first = exp->items[0];
  first->accept(*this);
  // NOTE: only the first element is inspected; the rest are assumed to match.
  if (auto elemType = first->getTypeExpr())
    exp->setTypeExpr(N<IndexExpr>(N<IdExpr>("List"), elemType));
}

/// Type a set literal as `set[T]`, where T is the type of its first element.
/// Empty literals and literals whose first element cannot be typed stay untyped,
/// so `set[...]` is never built with a missing (null) argument.
void AutoDeduceMembersTypecheckVisitor::visit(SetExpr *exp) {
  if (exp->items.empty())
    return;
  auto first = exp->items[0];
  first->accept(*this);
  // NOTE: only the first element is inspected; the rest are assumed to match.
  if (auto elemType = first->getTypeExpr())
    exp->setTypeExpr(N<IndexExpr>(N<IdExpr>("set"), elemType));
}

/// Type a dict literal as `Dict[K, V]` from its first entry.
/// A dictionary type always takes exactly two arguments, no matter how many
/// entries the literal has, so only one key/value pair is inspected.
/// Each entry is expected to be a two-element TupleExpr `(key, value)`
/// (NOTE(review): confirm against the parser). Any other entry form, such as a
/// `**mapping` expansion, leaves the literal untyped, as does an empty literal or
/// an untypable key/value.
void AutoDeduceMembersTypecheckVisitor::visit(DictExpr *exp) {
  if (exp->items.empty())
    return;
  auto entry = cast<TupleExpr>(exp->items[0]);
  if (!entry || entry->items.size() != 2)
    return;
  auto key = entry->items[0];
  auto value = entry->items[1];
  key->accept(*this);
  value->accept(*this);
  auto keyType = key->getTypeExpr();
  auto valueType = value->getTypeExpr();
  if (!keyType || !valueType)
    return;
  exp->setTypeExpr(N<IndexExpr>(
      N<IdExpr>("Dict"), N<TupleExpr>(std::vector<Expr *>{keyType, valueType})));
}

/// Type a unary operation.
/// Logical negation yields `bool`. "!" is accepted alongside "not" in case the
/// parser spells it that way. `+`, `-` and `~` keep the operand's type. Any other
/// operator leaves the expression untyped instead of inventing a type named after
/// the operator.
void AutoDeduceMembersTypecheckVisitor::visit(UnaryExpr *exp) {
  const auto &op = exp->op;
  if (op == "not" || op == "!") {
    exp->setTypeExpr(N<IdExpr>("bool"));
  } else if (op == "+" || op == "-" || op == "~") {
    exp->expr->accept(*this);
    exp->setTypeExpr(exp->expr->getTypeExpr());
  }
}

/// Type a binary operation by typing both operands (left first) and combining
/// their types with mergeTypeExpr().
void AutoDeduceMembersTypecheckVisitor::visit(BinaryExpr *exp) {
  auto *lhs = exp->lexpr;
  auto *rhs = exp->rexpr;
  lhs->accept(*this);
  rhs->accept(*this);
  auto *combined = mergeTypeExpr(exp->op, lhs->getTypeExpr(), rhs->getTypeExpr());
  exp->setTypeExpr(combined);
}

/// RangeExpr nodes are typed as `range`.
void AutoDeduceMembersTypecheckVisitor::visit(RangeExpr *exp) {
  Expr *rangeType = N<IdExpr>("range");
  exp->setTypeExpr(rangeType);
}

/// Type a list/set comprehension of the simple form `[<expr> for <id> in range(...)]`
/// (or the set equivalent) as `List[T]` / `set[T]`, where T is the type of
/// `<expr>` with the loop variable typed as `int`.
/// Any other comprehension shape, or an element expression that cannot be typed,
/// leaves the expression untyped.
void AutoDeduceMembersTypecheckVisitor::visit(GeneratorExpr *exp) {
  const bool isList = exp->kind == GeneratorExpr::ListGenerator;
  if (!isList && exp->kind != GeneratorExpr::SetGenerator)
    return;

  auto forStmt = cast<ForStmt>(exp->loops);
  if (!forStmt)
    return;
  // Only `range(...)` iterables are understood: their elements are ints.
  auto call = cast<CallExpr>(forStmt->getIter());
  auto callee = call ? cast<IdExpr>(call->expr) : nullptr;
  if (!callee || callee->value != "range")
    return;
  // The loop body must be a single expression statement (no filters/nesting).
  if (forStmt->getSuite()->items.size() != 1)
    return;
  auto body = cast<ExprStmt>(forStmt->getSuite()->items[0]);
  auto loopVar = cast<IdExpr>(forStmt->getVar());
  if (!body || !loopVar)
    return;

  // The loop variable is only visible inside the comprehension: bind it for the
  // element expression and drop it again afterwards so it cannot shadow outer
  // variables of the same name.
  addVar(loopVar->value, N<IdExpr>("int"));
  body->expr->accept(*this);
  args.pop_back();

  if (auto elemType = body->expr->getTypeExpr())
    exp->setTypeExpr(N<IndexExpr>(N<IdExpr>(isList ? "List" : "set"), elemType));
}

/// Infer the result type of the binary operation `l <op> r` from the (already
/// deduced) operand types.
///
/// Only well-understood cases are typed:
///  - comparisons and logical connectives yield `bool`;
///  - arithmetic on the built-in scalars follows the bool < int < float
///    promotion order;
///  - sequence repetition and same-type operations on other types keep the
///    operand type.
/// Everything else returns nullptr, so the member falls back to a generic type
/// instead of getting a wrong concrete one.
///
/// @param op binary operator spelling as stored in BinaryExpr
/// @param l  type expression of the left operand (nullptr if unknown)
/// @param r  type expression of the right operand (nullptr if unknown)
/// @return the result type expression, or nullptr if it cannot be inferred
Expr *AutoDeduceMembersTypecheckVisitor::mergeTypeExpr(std::string op, Expr *l,
                                                        Expr *r) {
  if (l == nullptr || r == nullptr)
    return nullptr;

  // Comparisons, identity/membership tests and logical connectives produce bool.
  // "&&" / "||" are accepted in case the parser spells `and` / `or` that way.
  if (op == "==" || op == "!=" || op == ">" || op == ">=" || op == "<" ||
      op == "<=" || op == "is" || op == "is not" || op == "in" || op == "not in" ||
      op == "and" || op == "or" || op == "&&" || op == "||")
    return N<IdExpr>("bool");

  // Position of a built-in scalar in the promotion order bool < int < float,
  // or -1 for any other type.
  auto scalarRank = [](Expr *e) {
    if (auto id = cast<IdExpr>(e)) {
      const auto name = id->getValue();
      if (name == "bool")
        return 0;
      if (name == "int")
        return 1;
      if (name == "float")
        return 2;
    }
    return -1;
  };
  const int lRank = scalarRank(l);
  const int rRank = scalarRank(r);

  if (lRank >= 0 && rRank >= 0) {
    const int top = lRank > rRank ? lRank : rRank;
    if (op == "/")
      return N<IdExpr>("float");
    // Arithmetic promotes bool to int; any float operand makes the result float.
    if (op == "+" || op == "-" || op == "*" || op == "%" || op == "**" ||
        op == "//")
      return N<IdExpr>(top == 2 ? "float" : "int");
    // Bitwise ops: bool op bool stays bool, ints give int, floats are unsupported.
    if (op == "&" || op == "|" || op == "^") {
      if (top == 0)
        return N<IdExpr>("bool");
      return top == 1 ? N<IdExpr>("int") : nullptr;
    }
    if (op == "<<" || op == ">>")
      return top < 2 ? N<IdExpr>("int") : nullptr;
    return nullptr;
  }

  // Sequence repetition: `seq * int` and `int * seq` keep the sequence type.
  if (op == "*" && rRank == 1)
    return l;
  if (op == "*" && lRank == 1)
    return r;
  // printf-style string formatting: `str % args` is a str.
  if (op == "%") {
    if (auto id = cast<IdExpr>(l); id && id->getValue() == "str")
      return l;
  }
  // Same-type operands of other types (str + str, List[T] + List[T], set | set,
  // ...) keep that type.
  if ((op == "+" || op == "-" || op == "*" || op == "%" || op == "**" ||
       op == "&" || op == "|" || op == "^") &&
      l->toString() == r->toString())
    return l;
  return nullptr;
}

} // namespace codon::ast
35 changes: 35 additions & 0 deletions codon/parser/visitors/typecheck/typecheck.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

namespace codon::ast {

class AutoDeduceMembersTypecheckVisitor;

/**
* Visitor that infers expression types and performs type-guided transformations.
*
Expand Down Expand Up @@ -201,6 +203,8 @@ class TypecheckVisitor : public ReplacingCallbackASTVisitor {
const std::string &, const Expr *,
types::ClassType *);
bool autoDeduceMembers(ClassStmt *, std::vector<Param> &);
Expr *inferMemberType(std::string, FunctionStmt *);
Expr *inferMemberType(std::string, std::string, Stmt *, AutoDeduceMembersTypecheckVisitor &);
static std::vector<Stmt *> getClassMethods(Stmt *s);
void transformNestedClasses(const ClassStmt *, std::vector<Stmt *> &,
std::vector<Stmt *> &, std::vector<Stmt *> &);
Expand Down Expand Up @@ -484,4 +488,35 @@ class TypecheckVisitor : public ReplacingCallbackASTVisitor {
// types::Type *getType(const std::string &);
};

// A simpler typechecker to infer the member type in advance
// based on the initializing right-hand side values.
// TODO: support method calls.
class AutoDeduceMembersTypecheckVisitor : public ASTVisitor {
public:
AutoDeduceMembersTypecheckVisitor(std::shared_ptr<TypeContext> ctx, std::vector<Param> &args)
: ctx(ctx), args(args) {}
void addVar(std::string name, Expr *typ) { args.emplace_back(name, typ); }
private:
template <typename Tn, typename... Ts> Tn *N(Ts &&...args) {
Tn *t = ctx->cache->N<Tn>(std::forward<Ts>(args)...);
return t;
}
std::shared_ptr<TypeContext> ctx;
std::vector<Param> args;
void visit(BoolExpr *) override;
void visit(IntExpr *) override;
void visit(FloatExpr *) override;
void visit(StringExpr *) override;
void visit(IdExpr *) override;
void visit(TupleExpr *) override;
void visit(ListExpr *) override;
void visit(SetExpr *) override;
void visit(DictExpr *) override;
void visit(UnaryExpr *) override;
void visit(BinaryExpr *) override;
void visit(RangeExpr *) override;
void visit(GeneratorExpr *) override;
Expr *mergeTypeExpr(std::string, Expr *, Expr *);
};

} // namespace codon::ast