commit b8a480fa2cf95b6df2eea64c491aed14ab6ac69d
parent 081c9ebae45ce57b874f18b64e3399f92e456442
Author: Vlad-Stefan Harbuz <vlad@vladh.net>
Date: Sat, 4 Jun 2022 11:51:23 +0100
add tuple unpacking for bindings
Signed-off-by: Vlad-Stefan Harbuz <vlad@vladh.net>
Diffstat:
M | include/ast.h | | | 6 | ++++++ |
M | include/expr.h | | | 7 | +++++++ |
M | src/check.c | | | 107 | +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ |
M | src/gen.c | | | 84 | +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ |
M | src/parse.c | | | 50 | ++++++++++++++++++++++++++++++++++++++++++++++++-- |
M | tests/21-tuples.ha | | | 108 | ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++- |
6 files changed, 359 insertions(+), 3 deletions(-)
diff --git a/include/ast.h b/include/ast.h
@@ -154,8 +154,14 @@ struct ast_expression_binarithm {
struct ast_expression *lvalue, *rvalue;
};
+struct ast_binding_unpack {
+ char *name;
+ struct ast_binding_unpack *next;
+};
+
struct ast_expression_binding {
char *name;
+ struct ast_binding_unpack *unpack;
struct ast_type *type;
unsigned int flags;
bool is_static;
diff --git a/include/expr.h b/include/expr.h
@@ -134,8 +134,15 @@ struct expression_binarithm {
struct expression *lvalue, *rvalue;
};
+struct binding_unpack {
+ const struct scope_object *object;
+ size_t offset;
+ struct binding_unpack *next;
+};
+
struct expression_binding {
const struct scope_object *object;
+ struct binding_unpack *unpack;
struct expression *initializer;
struct expression_binding *next;
};
diff --git a/src/check.c b/src/check.c
@@ -985,6 +985,106 @@ check_expr_binarithm(struct context *ctx,
}
static void
+check_binding_unpack(struct context *ctx,
+ const struct type *type,
+ const struct ast_expression_binding *abinding,
+ struct expression_binding *binding,
+ const struct ast_expression *aexpr,
+ struct expression *expr)
+{
+ assert(abinding->unpack);
+ const struct ast_binding_unpack *cur = abinding->unpack;
+ binding->unpack = xcalloc(1, sizeof(struct binding_unpack));
+ struct binding_unpack *unpack = binding->unpack;
+
+ struct expression *initializer = xcalloc(1, sizeof(struct expression));
+ check_expression(ctx, abinding->initializer, initializer, type);
+ if (initializer->result->storage != STORAGE_TUPLE) {
+ error(ctx, aexpr->loc, expr, "Could not unpack non-tuple type");
+ return;
+ }
+
+ if (!type) {
+ type = type_store_lookup_with_flags(
+ ctx->store, initializer->result, abinding->flags);
+ }
+
+ binding->initializer = lower_implicit_cast(type, initializer);
+
+ if (abinding->is_static) {
+ struct expression *value = xcalloc(1, sizeof(struct expression));
+ enum eval_result r = eval_expr(ctx, binding->initializer, value);
+ if (r != EVAL_OK) {
+ error(ctx, abinding->initializer->loc,
+ expr,
+ "Unable to evaluate static initializer at compile time");
+ return;
+ }
+ // TODO: Free initializer
+ binding->initializer = value;
+ assert(binding->initializer->type == EXPR_CONSTANT);
+ }
+
+ const struct type_tuple *type_tuple = &type->tuple;
+ bool found_binding = false;
+ while (cur) {
+ if (type_tuple->type->storage == STORAGE_NULL) {
+ error(ctx, aexpr->loc, expr,
+ "Null is not a valid type for a binding");
+ return;
+ }
+
+ if (cur->name) {
+ struct identifier ident = {
+ .name = cur->name,
+ };
+
+ if (abinding->is_static) {
+ struct identifier gen = {0};
+
+ // Generate a static declaration identifier
+ int n = snprintf(NULL, 0, "static.%d", ctx->id);
+ gen.name = xcalloc(n + 1, 1);
+ snprintf(gen.name, n + 1, "static.%d", ctx->id);
+ ++ctx->id;
+
+ unpack->object = scope_insert(
+ ctx->scope, O_DECL, &gen, &ident,
+ type_tuple->type, NULL);
+ } else {
+ unpack->object = scope_insert(
+ ctx->scope, O_BIND, &ident, &ident,
+ type_tuple->type, NULL);
+ }
+
+ unpack->offset = type_tuple->offset;
+
+ found_binding = true;
+ }
+
+ cur = cur->next;
+ type_tuple = type_tuple->next;
+
+ if (cur && found_binding && cur->name) {
+ unpack->next = xcalloc(1, sizeof(struct binding_unpack));
+ unpack = unpack->next;
+ }
+ }
+
+ if (!found_binding) {
+ error(ctx, aexpr->loc, expr,
+ "Must have at least one non-underscore value when unpacking tuples");
+ return;
+ }
+
+ if (type_tuple) {
+ error(ctx, aexpr->loc, expr,
+ "Fewer bindings than tuple elements were provided when unpacking");
+ return;
+ }
+}
+
+static void
check_expr_binding(struct context *ctx,
const struct ast_expression *aexpr,
struct expression *expr,
@@ -1006,6 +1106,12 @@ check_expr_binding(struct context *ctx,
type, type->flags | abinding->flags);
}
+ if (abinding->unpack) {
+ check_binding_unpack(ctx, type, abinding, binding,
+ aexpr, expr);
+ goto done;
+ }
+
struct identifier ident = {
.name = abinding->name,
};
@@ -1105,6 +1211,7 @@ check_expr_binding(struct context *ctx,
binding->initializer = value;
}
+done:
if (abinding->next) {
binding = *next =
xcalloc(1, sizeof(struct expression_binding));
diff --git a/src/gen.c b/src/gen.c
@@ -986,11 +986,95 @@ gen_expr_binarithm(struct gen_context *ctx, const struct expression *expr)
return result;
}
+static void
+gen_expr_binding_unpack_static(struct gen_context *ctx,
+ const struct expression_binding *binding)
+{
+ assert(binding->object == NULL);
+
+ struct tuple_constant *tupleconst =
+ binding->initializer->constant.tuple;
+
+ for (const struct binding_unpack *unpack = binding->unpack;
+ unpack; unpack = unpack->next) {
+ if (unpack->object == NULL) {
+ goto done;
+ }
+ assert(unpack->object->otype == O_DECL);
+
+ struct declaration decl = {
+ .type = DECL_GLOBAL,
+ .ident = unpack->object->ident,
+ .global = {
+ .type = unpack->object->type,
+ .value = tupleconst->value,
+ },
+ };
+ gen_global_decl(ctx, &decl);
+
+done:
+ tupleconst = tupleconst->next;
+ }
+}
+
+static void
+gen_expr_binding_unpack(struct gen_context *ctx,
+ const struct expression_binding *binding)
+{
+ assert(binding->object == NULL);
+
+ const struct type *type = binding->initializer->result;
+ char *tuple_name = gen_name(ctx, "tupleunpack.%d");
+ struct gen_value tuple_gv = {
+ .kind = GV_TEMP,
+ .type = type,
+ .name = tuple_name,
+ };
+ struct qbe_value tuple_qv = mklval(ctx, &tuple_gv);
+ struct qbe_value sz = constl(type->size);
+ enum qbe_instr alloc = alloc_for_align(type->align);
+ pushprei(ctx->current, &tuple_qv, alloc, &sz, NULL);
+
+ gen_expr_at(ctx, binding->initializer, tuple_gv);
+
+ for (const struct binding_unpack *unpack = binding->unpack;
+ unpack; unpack = unpack->next) {
+ if (unpack->object == NULL) {
+ continue;
+ }
+ assert(unpack->object->otype != O_DECL);
+
+ const struct type *type = unpack->object->type;
+ struct gen_value item_gv = {
+ .kind = GV_TEMP,
+ .type = type,
+ .name = gen_name(ctx, "binding.%d"),
+ };
+ struct gen_binding *gb = xcalloc(1, sizeof(struct gen_binding));
+ gb->value = item_gv;
+ gb->object = unpack->object;
+ gb->next = ctx->bindings;
+ ctx->bindings = gb;
+ struct qbe_value item_qv = mklval(ctx, &gb->value);
+ struct qbe_value offs = constl(unpack->offset);
+ pushprei(ctx->current, &item_qv, Q_ADD, &tuple_qv, &offs, NULL);
+ }
+}
+
static struct gen_value
gen_expr_binding(struct gen_context *ctx, const struct expression *expr)
{
for (const struct expression_binding *binding = &expr->binding;
binding; binding = binding->next) {
+ if (binding->unpack) {
+ if (binding->unpack->object->otype == O_DECL) {
+ gen_expr_binding_unpack_static(ctx, binding);
+ } else {
+ gen_expr_binding_unpack(ctx, binding);
+ }
+ continue;
+ }
+
if (binding->object->otype == O_DECL) {
// static binding
struct declaration decl = {
diff --git a/src/parse.c b/src/parse.c
@@ -2012,6 +2012,44 @@ parse_match_expression(struct lexer *lexer)
return exp;
}
+static void
+parse_binding_unpack(struct lexer *lexer, struct ast_binding_unpack **next)
+{
+ struct token tok = {0};
+
+ bool more = true;
+ while (more) {
+ char *name;
+
+ switch (lex(lexer, &tok)) {
+ case T_NAME:
+ name = tok.name;
+ break;
+ case T_UNDERSCORE:
+ name = NULL;
+ break;
+ default:
+ synassert(false, &tok, T_NAME, T_UNDERSCORE, T_EOF);
+ }
+
+ struct ast_binding_unpack *new = xcalloc(1, sizeof *new);
+ *next = new;
+ next = &new->next;
+
+ new->name = name;
+
+ switch (lex(lexer, &tok)) {
+ case T_COMMA:
+ break;
+ case T_RPAREN:
+ more = false;
+ break;
+ default:
+ synassert(false, &tok, T_COMMA, T_RPAREN, T_EOF);
+ }
+ }
+}
+
static struct ast_expression *
parse_binding_list(struct lexer *lexer, bool is_static)
{
@@ -2036,8 +2074,16 @@ parse_binding_list(struct lexer *lexer, bool is_static)
bool more = true;
while (more) {
- want(lexer, T_NAME, &tok);
- binding->name = tok.name;
+ switch (lex(lexer, &tok)) {
+ case T_NAME:
+ binding->name = tok.name;
+ break;
+ case T_LPAREN:
+ parse_binding_unpack(lexer, &binding->unpack);
+ break;
+ default:
+ synassert(false, &tok, T_NAME, T_LPAREN, T_EOF);
+ }
binding->initializer = mkexpr(&lexer->loc);
binding->flags = flags;
binding->is_static = is_static;
diff --git a/tests/21-tuples.ha b/tests/21-tuples.ha
@@ -1,3 +1,5 @@
+use rt;
+
fn storage() void = {
let x: (int, size) = (42, 1337);
assert(size((int, size)) == size(size) * 2);
@@ -16,7 +18,6 @@ fn indexing() void = {
};
fn func(in: (int, size)) (int, size) = (in.0 + 1, in.1 + 1);
-
fn eval_expr_access() void = {
static assert((42, 0).0 == 42 && (42, 0).1 == 0);
};
@@ -30,6 +31,110 @@ fn funcs() void = {
assert(x.0 == 42 && x.1 == 1337);
};
+fn unpacking_static() int = {
+ static let (a, b) = (0, 0);
+ a += 1;
+ b += 1;
+ return a;
+};
+
+fn unpacking_demo() (int, int) = {
+ return (10, 20);
+};
+
+fn unpacking_eval() (int, int) = {
+ static let i = 0;
+ const res = (10 + i, 20 + i);
+ i += 1;
+ return res;
+};
+
+let unpacking_global: int = 0i;
+
+fn unpacking_addone() int = {
+ unpacking_global += 1;
+ return unpacking_global;
+};
+
+fn unpacking() void = {
+ const (a, b, c) = (42, 8, 12);
+ assert(a == 42);
+ assert(b == 8);
+ assert(c == 12);
+
+ const (a, b): (i64, u64) = (2i, 4z);
+ assert(a == 2i64);
+ assert(b == 4u64);
+
+ const (a, b, c): (i64, str, f64) = (2i, "hello", 1.0);
+ assert(a == 2i64);
+ assert(b == "hello");
+ assert(c == 1.0f64);
+
+ let (a, b): (i64, u64) = (1i, 3z);
+ a += 1;
+ b += 1;
+ assert(a == 2i64);
+ assert(b == 4u64);
+
+ const (_, b, c) = (1, 2, 3);
+ assert(b == 2);
+ assert(c == 3);
+
+ const (a, _, c) = (1, 2, 3);
+ assert(a == 1);
+ assert(c == 3);
+
+ const (a, b, _) = (1, 2, 3);
+ assert(a == 1);
+ assert(b == 2);
+
+ unpacking_static();
+ unpacking_static();
+ const a = unpacking_static();
+ assert(a == 3);
+
+ const (a, b) = unpacking_demo();
+ assert(a == 10);
+ assert(b == 20);
+
+ const (a, b) = unpacking_eval();
+ assert(a == 10);
+ assert(b == 20);
+
+ let (a, b, _, d) = (unpacking_addone(), unpacking_addone(),
+ unpacking_addone(), unpacking_addone());
+ assert(a == 1 && b == 2 && d == 4);
+
+ assert(rt::compile("
+ export fn main() void = {
+ let (_, _) = (1, 2);
+ };
+ ") != 0);
+ assert(rt::compile("
+ export fn main() void = {
+ let (x, y) = (1, 2, 3);
+ };
+ ") != 0);
+ assert(rt::compile("
+ export fn main() void = {
+ let (x, y) = 5;
+ };
+ ") != 0);
+ assert(rt::compile("
+ fn getval() int = 5;
+ export fn main() void = {
+ static let (a, b) = (2, getval());
+ };
+ ") != 0);
+ assert(rt::compile("
+ fn getval() int = 5;
+ export fn main() void = {
+ static let (a, _) = (2, getval());
+ };
+ ") != 0);
+};
+
// Regression tests for miscellaneous compiler bugs
fn regression() void = {
let a: (((int | void), int) | void) = (void, 0);
@@ -41,5 +146,6 @@ export fn main() void = {
funcs();
eval_expr_tuple();
eval_expr_access();
+ unpacking();
regression();
};