commit a36b7b00417e90d192550202ff4b209a683a15c4
parent a86ae6299aa0c4df40065d5cca563fdfe83beb10
Author: Bor Grošelj Simić <bor.groseljsimic@telemach.net>
Date: Thu, 24 Mar 2022 01:44:26 +0100
implement type assertions on nullable pointers
Implements: https://todo.sr.ht/~sircmpwn/hare/220
Signed-off-by: Bor Grošelj Simić <bgs@turminal.net>
Diffstat:
M | src/check.c | | | 42 | +++++++++++++++++++++--------------------- |
M | src/gen.c | | | 71 | ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++----------- |
M | tests/03-pointers.ha | | | 34 | ++++++++++++++++++++++++++++++++++ |
3 files changed, 115 insertions(+), 32 deletions(-)
diff --git a/src/check.c b/src/check.c
@@ -1212,20 +1212,26 @@ check_expr_cast(struct context *ctx,
type_store_lookup_atype(ctx->store, aexpr->cast.type);
// TODO: Instead of allowing errors on casts to void, we should use a
// different nonterminal
- check_expression(ctx, aexpr->cast.value, value, secondary == &builtin_type_void ? NULL : secondary);
+ check_expression(ctx, aexpr->cast.value, value,
+ secondary == &builtin_type_void ? NULL : secondary);
- if (aexpr->cast.kind == C_ASSERTION || aexpr->cast.kind == C_TEST) {
- const struct type *primary = type_dealias(expr->cast.value->result);
+ const struct type *primary = type_dealias(expr->cast.value->result);
+ switch (aexpr->cast.kind) {
+ case C_ASSERTION:
+ case C_TEST:
+ if (primary->storage == STORAGE_POINTER) {
+ if (!(primary->pointer.flags & PTR_NULLABLE)) {
+ error(ctx, aexpr->cast.value->loc, expr,
+ "Expected a tagged union type or "
+ "a nullable pointer");
+ return;
+ }
+ break;
+ }
if (primary->storage != STORAGE_TAGGED) {
error(ctx, aexpr->cast.value->loc, expr,
- "Expected a tagged union type");
- return;
- }
- if (!type_is_castable(value->result, secondary)) {
- error(ctx, aexpr->cast.type->loc, expr,
- "Invalid cast from %s to %s",
- gen_typename(value->result),
- gen_typename(secondary));
+ "Expected a tagged union type or "
+ "a nullable pointer");
return;
}
bool found = false;
@@ -1238,12 +1244,11 @@ check_expr_cast(struct context *ctx,
}
if (!found) {
error(ctx, aexpr->cast.type->loc, expr,
- "Type is not a valid member of the tagged union type");
+ "Type is not a valid member of "
+ "the tagged union type");
return;
}
- }
-
- switch (aexpr->cast.kind) {
+ break;
case C_CAST:
if (!type_is_castable(secondary, value->result)) {
error(ctx, aexpr->cast.type->loc, expr,
@@ -1252,14 +1257,9 @@ check_expr_cast(struct context *ctx,
gen_typename(secondary));
return;
}
- // Fallthrough
- case C_ASSERTION:
- expr->result = secondary;
- break;
- case C_TEST:
- expr->result = &builtin_type_bool;
break;
}
+ expr->result = aexpr->cast.kind == C_TEST? &builtin_type_bool : secondary;
}
static void
diff --git a/src/gen.c b/src/gen.c
@@ -1388,16 +1388,19 @@ gen_expr_cast_at(struct gen_context *ctx,
static struct gen_value
gen_expr_cast(struct gen_context *ctx, const struct expression *expr)
{
- const struct type *to = expr->result, *from = expr->cast.value->result;
- switch (expr->cast.kind) {
- case C_TEST:
- return gen_expr_type_test(ctx, expr);
- case C_ASSERTION:
- assert(type_dealias(from)->storage == STORAGE_TAGGED);
- assert(tagged_select_subtype(from, to));
- // Fallthrough
- case C_CAST:
- break;
+ const struct type *to = expr->cast.secondary,
+ *from = expr->cast.value->result;
+ if (expr->cast.kind != C_CAST) {
+ bool is_valid_tagged, is_valid_pointer;
+ is_valid_tagged = type_dealias(from)->storage == STORAGE_TAGGED
+ && tagged_select_subtype(from, to);
+ is_valid_pointer = type_dealias(from)->storage == STORAGE_POINTER
+ && (type_dealias(to)->storage == STORAGE_POINTER
+ || type_dealias(to)->storage == STORAGE_NULL);
+ assert(is_valid_tagged || is_valid_pointer);
+ if (expr->cast.kind == C_TEST && is_valid_tagged) {
+ return gen_expr_type_test(ctx, expr);
+ }
}
if (cast_prefers_at(expr)) {
@@ -1415,14 +1418,60 @@ gen_expr_cast(struct gen_context *ctx, const struct expression *expr)
}
// Special cases
+ bool want_null = false;
switch (type_dealias(to)->storage) {
+ case STORAGE_NULL:
+ want_null = true;
+ // fallthrough
case STORAGE_POINTER:
if (type_dealias(from)->storage == STORAGE_SLICE) {
struct gen_value value = gen_expr(ctx, expr->cast.value);
value.type = to;
return gen_load(ctx, value);
}
- break;
+ if (type_dealias(from)->storage != STORAGE_POINTER) {
+ break;
+ }
+
+ struct gen_value val = gen_expr(ctx, expr->cast.value);
+ struct qbe_value qval = mkqval(ctx, &val);
+ if (expr->cast.kind == C_TEST) {
+ struct gen_value out =
+ mktemp(ctx, &builtin_type_bool, ".%d");
+ struct qbe_value qout = mkqval(ctx, &out);
+ struct qbe_value zero = constl(0);
+
+ enum qbe_instr compare = want_null? Q_CEQL : Q_CNEL;
+ pushi(ctx->current, &qout, compare, &qval, &zero, NULL);
+ return out;
+ } else if (expr->cast.kind == C_ASSERTION) {
+ struct qbe_statement failedl, passedl;
+ struct qbe_value bfailed =
+ mklabel(ctx, &failedl, "failed.%d");
+ struct qbe_value bpassed =
+ mklabel(ctx, &passedl, "passed.%d");
+
+ if (want_null) {
+ pushi(ctx->current, NULL, Q_JNZ, &qval,
+ &bfailed, &bpassed, NULL);
+ } else {
+ pushi(ctx->current, NULL, Q_JNZ, &qval,
+ &bpassed, &bfailed, NULL);
+ }
+ push(&ctx->current->body, &failedl);
+ gen_fixed_abort(ctx, expr->loc, ABORT_TYPE_ASSERTION);
+
+ push(&ctx->current->body, &passedl);
+ if (want_null) {
+ return (struct gen_value){
+ .kind = GV_CONST,
+ .type = &builtin_type_null,
+ .lval = 0,
+ };
+ }
+ }
+ val.type = to;
+ return val;
case STORAGE_VOID:
gen_expr(ctx, expr->cast.value); // Side-effects
return gv_void;
diff --git a/tests/03-pointers.ha b/tests/03-pointers.ha
@@ -22,6 +22,31 @@ fn _nullable() void = {
) != 0);
};
+fn casts() void = {
+ let a: *uint = &4u;
+ let b = a: *void;
+ let c = b: *uint;
+ assert(a == c && *c == 4);
+
+ let a: nullable *uint = &7u;
+ let b = a: *uint;
+ assert(b == a && *b == 7);
+
+ let a: nullable *uint = &10u;
+ let b = a as *uint;
+ assert(b == a && *b == 10);
+
+ let a: nullable *int = &4;
+ assert(a is *int);
+
+ let a: nullable *int = null;
+ assert(a is null);
+ assert((a as null): nullable *void == null);
+
+ let a: nullable *int = &4;
+ a as nullable *int;
+};
+
fn reject() void = {
assert(rt::compile("
type s = null;
@@ -65,10 +90,19 @@ fn reject() void = {
let a = null;
};
") != 0);
+
+ // type assertions on non-nullable pointers are prohibited
+ assert(rt::compile("
+ fn test() void = {
+ let a: *int = &4;
+ assert(a as *int);
+ };
+ ") != 0);
};
export fn main() void = {
basics();
_nullable();
+ casts();
reject();
};