commit 53434fd6baca1cd610d00d9e4c65508ba2babb2f
parent 683528ff71e27bd2b643c0ad3b92322034d79bbd
Author: Bor Grošelj Simić <bor.groseljsimic@telemach.net>
Date: Mon, 28 Mar 2022 22:52:49 +0200
allow casting with nested tagged unions
suppose we define two tagged types and a variable:
type foo = (int | void);
type bar = (size | foo);
let a: bar = 1337;
Before this change, the best way to get the integer value from the
nested tagged union was to use a match:
match (a) {
case let i: int =>
...
case =>
abort();
};
This change makes casting a nested tagged union to types that are
members of its member types possible, allowing us to rewrite the match
expression used above to:
let i = a as int;
Signed-off-by: Bor Grošelj Simić <bgs@turminal.net>
Diffstat:
2 files changed, 69 insertions(+), 39 deletions(-)
diff --git a/src/gen.c b/src/gen.c
@@ -1111,6 +1111,11 @@ static struct gen_value gen_subset_match_tests(struct gen_context *ctx,
struct qbe_value bmatch, struct qbe_value bnext,
struct qbe_value tag, const struct type *type);
+static struct gen_value gen_nested_match_tests(struct gen_context *ctx,
+ struct gen_value object, struct qbe_value bmatch,
+ struct qbe_value bnext, struct qbe_value tag,
+ const struct type *type);
+
static struct gen_value
gen_expr_type_test(struct gen_context *ctx, const struct expression *expr)
{
@@ -1125,44 +1130,38 @@ gen_expr_type_test(struct gen_context *ctx, const struct expression *expr)
enum qbe_instr load = load_for_type(ctx, &builtin_type_uint);
struct gen_value result = {0};
+ struct qbe_statement endl;
+ struct qbe_value bend = mklabel(ctx, &endl, ".%d");
pushi(ctx->current, &tag, load, &qval, NULL);
if (tagged_select_subtype(from, secondary) != NULL) {
- result = mktemp(ctx, &builtin_type_bool, ".%d");
- struct qbe_value qr = mkqval(ctx, &result);
- struct qbe_value expected = constl(secondary->id);
- pushi(ctx->current, &qr, Q_CEQW, &tag, &expected, NULL);
+ result = gen_nested_match_tests(ctx, val, bend, bend, tag, secondary);
} else if (tagged_subset_compat(from, secondary)) {
- struct qbe_statement endl;
- struct qbe_value bend = mklabel(ctx, &endl, ".%d");
result = gen_subset_match_tests(ctx, bend, bend, tag,
type_dealias(secondary));
- push(&ctx->current->body, &endl);
} else {
abort();
}
+ push(&ctx->current->body, &endl);
return result;
}
static void
gen_type_assertion(struct gen_context *ctx,
const struct expression *expr,
- struct qbe_value base)
+ struct gen_value base)
{
const struct type *want = expr->result;
struct qbe_value tag = mkqtmp(ctx,
qtype_lookup(ctx, &builtin_type_uint, false), ".%d");
enum qbe_instr load = load_for_type(ctx, &builtin_type_uint);
- pushi(ctx->current, &tag, load, &base, NULL);
+ struct qbe_value qbase = mkqval(ctx, &base);
+ pushi(ctx->current, &tag, load, &qbase, NULL);
struct qbe_statement failedl, passedl;
struct qbe_value bfailed = mklabel(ctx, &failedl, "failed.%d");
struct qbe_value bpassed = mklabel(ctx, &passedl, "passed.%d");
if (tagged_select_subtype(expr->cast.value->result, want)) {
- struct gen_value result = mktemp(ctx, &builtin_type_bool, ".%d");
- struct qbe_value expected = constl(want->id);
- struct qbe_value qr = mkqval(ctx, &result);
- pushi(ctx->current, &qr, Q_CEQW, &tag, &expected, NULL);
- pushi(ctx->current, NULL, Q_JNZ, &qr, &bpassed, &bfailed, NULL);
+ gen_nested_match_tests(ctx, base, bpassed, bfailed, tag, want);
} else if (tagged_subset_compat(expr->cast.value->result, want)) {
gen_subset_match_tests(ctx, bpassed, bfailed, tag,
type_dealias(want));
@@ -1232,11 +1231,9 @@ gen_expr_cast_tagged_at(struct gen_context *ctx,
// type 'from' as if it were of type 'to'
struct gen_value out2 = out;
out2.type = from;
- struct gen_value val = gen_expr(ctx, expr->cast.value);
- gen_store(ctx, out2, val);
- struct qbe_value qout = mkqval(ctx, &out2);
+ gen_expr_at(ctx, expr->cast.value, out2);
if (expr->cast.kind == C_ASSERTION) {
- gen_type_assertion(ctx, expr, qout);
+ gen_type_assertion(ctx, expr, out2);
}
} else if (!subtype) {
// Case 2: like case 1, but with an alignment mismatch; more
@@ -1244,7 +1241,7 @@ gen_expr_cast_tagged_at(struct gen_context *ctx,
struct gen_value value = gen_expr(ctx, expr->cast.value);
struct qbe_value qval = mkqval(ctx, &value);
if (expr->cast.kind == C_ASSERTION) {
- gen_type_assertion(ctx, expr, qval);
+ gen_type_assertion(ctx, expr, value);
}
struct qbe_value qout = mkqval(ctx, &out);
struct qbe_value tag = mkqtmp(ctx,
@@ -1273,7 +1270,6 @@ gen_expr_cast_tagged_at(struct gen_context *ctx,
gen_copy_aligned(ctx, iout, ival);
} else {
// Case 3: from is a member of to
- assert(expr->cast.kind == C_CAST);
assert(subtype == from); // Lowered by check
struct qbe_value qout = mkqval(ctx, &out);
struct qbe_value id = constw(subtype->id);
@@ -1405,6 +1401,9 @@ gen_expr_cast_at(struct gen_context *ctx,
}
}
+static struct qbe_value nested_tagged_offset(const struct type *tu,
+ const struct type *targed);
+
static struct gen_value
gen_expr_cast(struct gen_context *ctx, const struct expression *expr)
{
@@ -1499,16 +1498,16 @@ gen_expr_cast(struct gen_context *ctx, const struct expression *expr)
default: break;
}
-
// Special case: tagged => non-tagged
if (type_dealias(from)->storage == STORAGE_TAGGED) {
struct gen_value value = gen_expr(ctx, expr->cast.value);
struct qbe_value base = mkcopy(ctx, &value, ".%d");
if (expr->cast.kind == C_ASSERTION) {
- gen_type_assertion(ctx, expr, base);
+ gen_type_assertion(ctx, expr, value);
}
- struct qbe_value align = constl(from->align);
+ struct qbe_value align = nested_tagged_offset(
+ expr->cast.value->result, expr->cast.secondary);
pushi(ctx->current, &base, Q_ADD, &base, &align, NULL);
struct gen_value storage = (struct gen_value){
.kind = GV_TEMP,
@@ -2167,11 +2166,35 @@ enum match_compat {
COMPAT_MISALIGNED,
};
-static void
+static struct qbe_value
+nested_tagged_offset(const struct type *tu, const struct type *target)
+{
+ // This function calculates the offset of a member in a nested tagged union
+ //
+ // type foo = (int | void);
+ // type bar = (size | foo);
+ //
+ // Offset of the foo field from the start of bar is 8 and offset of int
+ // inside foo is 4, so the offset of the int from the start of bar is 12
+ const struct type *test = tu;
+ struct qbe_value offset = constl(tu->align);
+ do {
+ test = tagged_select_subtype(tu, target);
+ if (!test) {
+ break;
+ }
+ if (test->id != target->id && type_dealias(test)->id != target->id) {
+ offset.lval += test->align;
+ }
+ tu = test;
+ } while (test->id != target->id && type_dealias(test)->id != target->id);
+ return offset;
+}
+
+static struct gen_value
gen_nested_match_tests(struct gen_context *ctx, struct gen_value object,
struct qbe_value bmatch, struct qbe_value bnext,
- struct qbe_value tag, const struct match_case *_case,
- struct qbe_value *offset)
+ struct qbe_value tag, const struct type *type)
{
// This function handles the case where we're matching against a type
// which is a member of the tagged union, or an inner tagged union.
@@ -2191,38 +2214,36 @@ gen_nested_match_tests(struct gen_context *ctx, struct gen_value object,
// tag of the foo object for int.
struct qbe_value *subtag = &tag;
struct qbe_value subval = mkcopy(ctx, &object, "subval.%d");
- struct qbe_value match = mkqtmp(ctx, &qbe_word, ".%d");
+ struct gen_value match = mktemp(ctx, &builtin_type_bool, ".%d");
+ struct qbe_value qmatch = mkqval(ctx, &match);
struct qbe_value temp = mkqtmp(ctx, &qbe_word, ".%d");
const struct type *subtype = object.type;
- const struct type *test = _case->type;
- *offset = constl(subtype->align);
+ const struct type *test = type;
do {
struct qbe_statement lsubtype;
struct qbe_value bsubtype = mklabel(ctx, &lsubtype, "subtype.%d");
- test = tagged_select_subtype(subtype, _case->type);
+ test = tagged_select_subtype(subtype, type);
if (!test) {
break;
}
struct qbe_value id = constw(test->id);
- pushi(ctx->current, &match, Q_CEQW, subtag, &id, NULL);
- pushi(ctx->current, NULL, Q_JNZ, &match, &bsubtype, &bnext, NULL);
+ pushi(ctx->current, &qmatch, Q_CEQW, subtag, &id, NULL);
+ pushi(ctx->current, NULL, Q_JNZ, &qmatch, &bsubtype, &bnext, NULL);
push(&ctx->current->body, &lsubtype);
- if (test->id != _case->type->id
- && type_dealias(test)-> id != _case->type->id) {
+ if (test->id != type->id && type_dealias(test)-> id != type->id) {
struct qbe_value offs = constl(subtype->align);
pushi(ctx->current, &subval, Q_ADD, &subval, &offs, NULL);
pushi(ctx->current, &temp, Q_LOADUW, &subval, NULL);
- offset->lval += test->align;
subtag = &temp;
}
subtype = test;
- } while (test->id != _case->type->id
- && type_dealias(test)->id != _case->type->id);
+ } while (test->id != type->id && type_dealias(test)->id != type->id);
pushi(ctx->current, NULL, Q_JMP, &bmatch, NULL);
+ return match;
}
static struct gen_value
@@ -2286,7 +2307,6 @@ gen_match_with_tagged(struct gen_context *ctx,
continue;
}
- struct qbe_value offset;
struct qbe_statement lmatch, lnext;
struct qbe_value bmatch = mklabel(ctx, &lmatch, "matches.%d");
struct qbe_value bnext = mklabel(ctx, &lnext, "next.%d");
@@ -2295,7 +2315,7 @@ gen_match_with_tagged(struct gen_context *ctx,
enum match_compat compat = COMPAT_SUBTYPE;
if (subtype) {
gen_nested_match_tests(ctx, object,
- bmatch, bnext, tag, _case, &offset);
+ bmatch, bnext, tag, _case->type);
} else {
assert(type_dealias(_case->type)->storage == STORAGE_TAGGED);
assert(tagged_subset_compat(objtype, _case->type));
@@ -2334,8 +2354,10 @@ gen_match_with_tagged(struct gen_context *ctx,
.name = ptr.name,
};
struct gen_value load;
+ struct qbe_value offset;
switch (compat) {
case COMPAT_SUBTYPE:
+ offset = nested_tagged_offset(object.type, _case->type);
pushi(ctx->current, &ptr, Q_ADD, &qobject, &offset, NULL);
load = gen_load(ctx, src);
gen_store(ctx, gb->value, load);
diff --git a/tests/13-tagged.ha b/tests/13-tagged.ha
@@ -167,12 +167,20 @@ fn subsetcast() void = {
assert(x is (int | void) && (x is int) && !(x is void));
};
+type foo = (int | void);
+type bar = (size | foo);
+
fn castout() void = {
let x: (int | void) = 1337;
assert(x: int == 1337);
assert(x as int == 1337);
assert(x is int);
// XXX: We can probably expand this
+
+ let a: bar = 42i;
+ assert(a as int == 42);
+ assert(a: int == 42);
+ assert(a is int);
};
fn assertions() void = {