commit c5bc81947bbf5f243754040dfab049e282739477
parent 81fd59f11f174708cdac9150fa3ff50e462f2e4b
Author: Drew DeVault <sir@cmpwn.com>
Date: Wed, 15 Sep 2021 09:41:55 +0200
all: rework switch grammar & semantics
This changes the syntax of switch (and match) expressions as follows:
let y: int = switch (x) {
case 0 =>
yield x + 1;
case 1 =>
yield x + 2;
case 10, 11, 12 =>
yield x + 10;
case =>
void;
yield x;
};
Diffstat:
10 files changed, 249 insertions(+), 132 deletions(-)
diff --git a/include/ast.h b/include/ast.h
@@ -249,7 +249,7 @@ struct ast_expression_compound {
struct ast_match_case {
char *name; // May be null
struct ast_type *type;
- struct ast_expression *value;
+ struct ast_expression_list exprs;
struct ast_match_case *next;
};
@@ -288,7 +288,7 @@ struct ast_case_option {
struct ast_switch_case {
struct ast_case_option *options; // NULL for *
- struct ast_expression *value;
+ struct ast_expression_list exprs;
struct ast_switch_case *next;
};
diff --git a/include/expr.h b/include/expr.h
@@ -296,7 +296,7 @@ struct case_option {
};
struct switch_case {
- struct case_option *options; // NULL for *
+ struct case_option *options; // NULL for default case
struct expression *value;
struct switch_case *next;
};
diff --git a/include/lex.h b/include/lex.h
@@ -20,6 +20,7 @@ enum lexical_token {
T_ASSERT,
T_BOOL,
T_BREAK,
+ T_CASE,
T_CHAR,
T_CONST,
T_CONTINUE,
@@ -71,11 +72,11 @@ enum lexical_token {
T_LAST_KEYWORD = T_YIELD,
// Operators
+ T_ARROW,
T_BANDEQ,
T_BAND,
T_BNOT,
T_BOR,
- T_CASE,
T_COLON,
T_COMMA,
T_DIV,
diff --git a/src/check.c b/src/check.c
@@ -1871,7 +1871,16 @@ check_expr_match(struct context *ctx,
_case->value = xcalloc(1, sizeof(struct expression));
_case->type = ctype;
- check_expression(ctx, acase->value, _case->value, hint);
+
+ // Lower to compound
+ // TODO: This should probably be done in a more first-class way
+ struct ast_expression compound = {
+ .type = EXPR_COMPOUND,
+ .compound = {
+ .list = acase->exprs,
+ },
+ };
+ check_expression(ctx, &compound, _case->value, hint);
if (acase->name) {
scope_pop(&ctx->scope);
@@ -1915,7 +1924,7 @@ check_expr_match(struct context *ctx,
while (_case) {
if (!_case->value->terminates && !type_is_assignable(
expr->result, _case->value->result)) {
- error(ctx, acase->value->loc, expr,
+ error(ctx, acase->exprs.expr->loc, expr,
"Match case is not assignable to result type");
return;
}
@@ -2447,7 +2456,17 @@ check_expr_switch(struct context *ctx,
}
_case->value = xcalloc(1, sizeof(struct expression));
- check_expression(ctx, acase->value, _case->value, hint);
+
+ // Lower to compound
+ // TODO: This should probably be done in a more first-class way
+ struct ast_expression compound = {
+ .type = EXPR_COMPOUND,
+ .compound = {
+ .list = acase->exprs,
+ },
+ };
+ check_expression(ctx, &compound, _case->value, hint);
+
if (_case->value->terminates) {
continue;
}
@@ -2486,7 +2505,7 @@ check_expr_switch(struct context *ctx,
while (_case) {
if (!_case->value->terminates && !type_is_assignable(
expr->result, _case->value->result)) {
- error(ctx, acase->value->loc, expr,
+ error(ctx, acase->exprs.expr->loc, expr,
"Switch case is not assignable to result type");
return;
}
@@ -3211,8 +3230,11 @@ expr_is_specified(struct context *ctx, const struct ast_expression *aexpr)
if (!type_is_specified(ctx, mcase->type)) {
return false;
}
- if (!expr_is_specified(ctx, mcase->value)) {
- return false;
+ for (const struct ast_expression_list *list = &mcase->exprs;
+ list; list = list->next) {
+ if (!expr_is_specified(ctx, list->expr)) {
+ return false;
+ }
}
}
return expr_is_specified(ctx, aexpr->match.value);
@@ -3259,8 +3281,11 @@ expr_is_specified(struct context *ctx, const struct ast_expression *aexpr)
return false;
}
}
- if (!expr_is_specified(ctx, scase->value)) {
- return false;
+ for (const struct ast_expression_list *list = &scase->exprs;
+ list; list = list->next) {
+ if (!expr_is_specified(ctx, list->expr)) {
+ return false;
+ }
}
}
return expr_is_specified(ctx, aexpr->_switch.value);
diff --git a/src/lex.c b/src/lex.c
@@ -28,11 +28,12 @@ static const char *tokens[] = {
[T_ASSERT] = "assert",
[T_BOOL] = "bool",
[T_BREAK] = "break",
+ [T_CASE] = "case",
[T_CHAR] = "char",
[T_CONST] = "const",
[T_CONTINUE] = "continue",
- [T_DEF] = "def",
[T_DEFER] = "defer",
+ [T_DEF] = "def",
[T_DELETE] = "delete",
[T_ELSE] = "else",
[T_ENUM] = "enum",
@@ -78,11 +79,11 @@ static const char *tokens[] = {
[T_YIELD] = "yield",
// Operators
+ [T_ARROW] = "=>",
[T_BANDEQ] = "&=",
[T_BAND] = "&",
[T_BNOT] = "~",
[T_BOR] = "|",
- [T_CASE] = "=>",
[T_COLON] = ":",
[T_COMMA] = ",",
[T_DIV] = "/",
@@ -855,7 +856,7 @@ lex2(struct lexer *lexer, struct token *out, uint32_t c)
out->token = T_LEQUAL;
break;
case '>':
- out->token = T_CASE;
+ out->token = T_ARROW;
break;
default:
push(lexer, c, false);
@@ -946,6 +947,9 @@ _lex(struct lexer *lexer, struct token *out)
case '?':
out->token = T_QUESTION;
break;
+ case '`':
+ out->token = T_CASE;
+ break;
default:
out->token = T_ERROR;
break;
diff --git a/src/parse.c b/src/parse.c
@@ -1721,9 +1721,8 @@ parse_case_options(struct lexer *lexer)
{
struct token tok = {0};
switch (lex(lexer, &tok)) {
- case T_TIMES:
- want(lexer, T_CASE, &tok);
- return NULL;
+ case T_ARROW:
+ return NULL; // Default case
default:
unlex(lexer, &tok);
break;
@@ -1738,7 +1737,7 @@ parse_case_options(struct lexer *lexer)
switch (lex(lexer, &tok)) {
case T_COMMA:
switch (lex(lexer, &tok)) {
- case T_CASE:
+ case T_ARROW:
more = false;
break;
default:
@@ -1749,11 +1748,11 @@ parse_case_options(struct lexer *lexer)
break;
}
break;
- case T_CASE:
+ case T_ARROW:
more = false;
break;
default:
- synassert(false, &tok, T_CASE, T_COMMA, T_EOF);
+ synassert(false, &tok, T_COMMA, T_ARROW, T_EOF);
break;
}
}
@@ -1771,7 +1770,6 @@ parse_switch_expression(struct lexer *lexer)
want(lexer, T_LPAREN, &tok);
exp->_switch.value = parse_expression(lexer);
want(lexer, T_RPAREN, &tok);
-
want(lexer, T_LBRACE, &tok);
bool more = true;
@@ -1779,25 +1777,42 @@ parse_switch_expression(struct lexer *lexer)
while (more) {
struct ast_switch_case *_case =
*next_case = xcalloc(1, sizeof(struct ast_switch_case));
+ want(lexer, T_CASE, &tok);
_case->options = parse_case_options(lexer);
- _case->value = parse_expression(lexer);
- switch (lex(lexer, &tok)) {
- case T_COMMA:
+ bool exprs = true;
+ struct ast_expression_list *cur = &_case->exprs;
+ struct ast_expression_list **next = &cur->next;
+ while (exprs) {
+ cur->expr = parse_expression(lexer);
+ want(lexer, T_SEMICOLON, &tok);
+
switch (lex(lexer, &tok)) {
+ case T_CASE:
case T_RBRACE:
- more = false;
+ exprs = false;
break;
default:
- unlex(lexer, &tok);
break;
}
+ unlex(lexer, &tok);
+
+ if (exprs) {
+ *next = xcalloc(1, sizeof(struct ast_expression_list));
+ cur = *next;
+ next = &cur->next;
+ }
+ }
+
+ switch (lex(lexer, &tok)) {
+ case T_CASE:
+ unlex(lexer, &tok);
break;
case T_RBRACE:
more = false;
break;
default:
- synassert(false, &tok, T_COMMA, T_RBRACE, T_EOF);
+ synassert(false, &tok, T_CASE, T_RBRACE, T_EOF);
}
next_case = &_case->next;
@@ -1823,6 +1838,7 @@ parse_match_expression(struct lexer *lexer)
while (more) {
struct ast_match_case *_case =
*next_case = xcalloc(1, sizeof(struct ast_match_case));
+ want(lexer, T_CASE, &tok);
struct token tok2 = {0};
struct identifier ident = {0};
@@ -1842,7 +1858,7 @@ parse_match_expression(struct lexer *lexer)
_case->type->storage = STORAGE_ALIAS;
_case->type->alias = ident;
break;
- case T_CASE:
+ case T_ARROW:
unlex(lexer, &tok2);
_case->type = mktype(&tok.loc);
_case->type->storage = STORAGE_ALIAS;
@@ -1850,24 +1866,13 @@ parse_match_expression(struct lexer *lexer)
break;
default:
synassert(false, &tok, T_COLON,
- T_DOUBLE_COLON, T_CASE, T_EOF);
+ T_DOUBLE_COLON, T_ARROW, T_EOF);
break;
}
break;
- case T_TIMES:
- switch (lex(lexer, &tok2)) {
- case T_CASE: // Default case
- unlex(lexer, &tok2);
- break;
- default:
- unlex(lexer, &tok2);
- _case->type = parse_type(lexer);
- struct ast_type *ptr = mktype(&tok.loc);
- ptr->storage = STORAGE_POINTER;
- ptr->pointer.referent = _case->type;
- _case->type = ptr;
- break;
- }
+ case T_ARROW:
+ // Default case
+ unlex(lexer, &tok);
break;
case T_NULL:
type = mktype(&tok.loc);
@@ -1880,25 +1885,41 @@ parse_match_expression(struct lexer *lexer)
break;
}
- want(lexer, T_CASE, &tok);
- _case->value = parse_expression(lexer);
+ want(lexer, T_ARROW, &tok);
+
+ bool exprs = true;
+ struct ast_expression_list *cur = &_case->exprs;
+ struct ast_expression_list **next = &cur->next;
+ while (exprs) {
+ cur->expr = parse_expression(lexer);
+ want(lexer, T_SEMICOLON, &tok);
- switch (lex(lexer, &tok)) {
- case T_COMMA:
switch (lex(lexer, &tok)) {
+ case T_CASE:
case T_RBRACE:
- more = false;
+ exprs = false;
break;
default:
- unlex(lexer, &tok);
break;
}
+ unlex(lexer, &tok);
+
+ if (exprs) {
+ *next = xcalloc(1, sizeof(struct ast_expression_list));
+ cur = *next;
+ next = &cur->next;
+ }
+ }
+
+ switch (lex(lexer, &tok)) {
+ case T_CASE:
+ unlex(lexer, &tok);
break;
case T_RBRACE:
more = false;
break;
default:
- synassert(false, &tok, T_COMMA, T_RBRACE, T_EOF);
+ synassert(false, &tok, T_CASE, T_RBRACE, T_EOF);
}
next_case = &_case->next;
diff --git a/tests/14-switch.ha b/tests/14-switch.ha
@@ -3,12 +3,14 @@ fn basics() void = {
for (let i = 0z; i < len(cases); i += 1) {
let x = cases[i][0];
let y: int = switch (x) {
- 0 => x + 1,
- 1 => x + 2,
- 10, 11, 12 => x + 10,
- * => {
- yield x;
- },
+ case 0 =>
+ yield x + 1;
+ case 1 =>
+ yield x + 2;
+ case 10, 11, 12 =>
+ yield x + 10;
+ case =>
+ yield x;
};
assert(y == cases[i][1]);
};
@@ -17,9 +19,12 @@ fn basics() void = {
fn termination() void = {
let x = 42;
let y: int = switch (x) {
- 42 => 1337,
- 24 => abort(),
- * => abort(),
+ case 42 =>
+ yield 1337;
+ case 24 =>
+ abort();
+ case =>
+ abort();
};
assert(y == 1337);
};
@@ -27,23 +32,31 @@ fn termination() void = {
fn tagged_result() void = {
let x = 42;
let y: (int | uint) = switch (x) {
- 42 => 1337i,
- * => 1337u,
+ case 42 =>
+ yield 1337i;
+ case =>
+ yield 1337u;
};
assert(y is int);
x = 24;
y = switch (x) {
- 42 => 1337i,
- * => 1337u,
+ case 42 =>
+ yield 1337i;
+ case =>
+ yield 1337u;
};
assert(y is uint);
};
fn str_switching() void = {
let result = switch ("hare") {
- "world" => abort(),
- "hare" => true,
+ case "world" =>
+ abort();
+ case "hare" =>
+ yield true;
+ case =>
+ abort();
};
assert(result == true);
};
diff --git a/tests/18-match.ha b/tests/18-match.ha
@@ -3,9 +3,12 @@ fn tagged() void = {
let expected: [_]size = [1, 2, 5];
for (let i = 0z; i < len(cases); i += 1) {
let y: size = match (cases[i]) {
- int => 1,
- uint => 2,
- s: str => len(s),
+ case int =>
+ yield 1;
+ case uint =>
+ yield 2;
+ case s: str =>
+ yield len(s);
};
assert(y == expected[i]);
};
@@ -15,20 +18,25 @@ fn termination() void = {
let x: (int | uint | str) = 1337i;
for (true) {
let y: int = match (x) {
- int => 42,
- uint => abort(),
- str => break,
+ case int =>
+ yield 42;
+ case uint =>
+ abort();
+ case str =>
+ break;
};
assert(y == 42);
x = "hi";
};
};
-fn default() void = {
+fn _default() void = {
let x: (int | uint | str) = 1337u;
let y: int = match (x) {
- int => 42,
- * => 24,
+ case int =>
+ yield 42;
+ case =>
+ yield 24;
};
assert(y == 24);
};
@@ -37,15 +45,19 @@ fn pointer() void = {
let x = 42;
let y: nullable *int = &x;
let z: int = match (y) {
- y: *int => *y,
- null => abort(),
+ case y: *int =>
+ yield *y;
+ case null =>
+ abort();
};
assert(z == 42);
y = null;
z = match(y) {
- *int => abort(),
- null => 1337,
+ case *int =>
+ abort();
+ case null =>
+ yield 1337;
};
assert(z == 1337);
};
@@ -59,8 +71,10 @@ fn alias() void = {
let expected = [42, 24];
for (let i = 0z; i < len(cases); i += 1) {
let y: int = match (cases[i]) {
- foo => 42,
- bar => 24,
+ case foo =>
+ yield 42;
+ case bar =>
+ yield 24;
};
assert(y == expected[i]);
};
@@ -69,15 +83,19 @@ fn alias() void = {
fn tagged_result() void = {
let x: (int | uint) = 42i;
let y: (int | uint) = match (x) {
- x: int => x,
- x: uint => x,
+ case x: int =>
+ yield x;
+ case x: uint =>
+ yield x;
};
assert(y is int);
x = 42u;
y = match (x) {
- x: int => x,
- x: uint => x,
+ case x: int =>
+ yield x;
+ case x: uint =>
+ yield x;
};
assert(y is uint);
};
@@ -86,8 +104,10 @@ fn implicit_cast() void = {
let x: foobar = foo;
let y: nullable *int = null;
let a: (int | foobar) = match (y) {
- null => foo,
- z: *int => *z,
+ case null =>
+ yield foo;
+ case z: *int =>
+ yield *z;
};
assert(a is foobar);
};
@@ -101,53 +121,68 @@ type foobarbaz = (foobar | baz);
fn transitivity() void = {
let x: (foobar | int) = 10;
match (x) {
- i: int => assert(i == 10),
- foo => abort(),
- bar => abort(),
+ case i: int =>
+ assert(i == 10);
+ case foo =>
+ abort();
+ case bar =>
+ abort();
};
x = foo;
let visit = false;
match (x) {
- int => abort(),
- foo => visit = true,
- bar => abort(),
+ case int =>
+ abort();
+ case foo =>
+ visit = true;
+ case bar =>
+ abort();
};
assert(visit);
x = bar;
visit = false;
match (x) {
- int => abort(),
- foo => abort(),
- foobar => visit = true,
+ case int =>
+ abort();
+ case foo =>
+ abort();
+ case foobar =>
+ visit = true;
};
assert(visit);
visit = false;
match (x) {
- z: (foo | bar) => {
- visit = true;
- assert(z is bar);
- },
- int => abort(),
+ case z: (foo | bar) =>
+ visit = true;
+ assert(z is bar);
+ case int =>
+ abort();
};
assert(visit);
let y: foobarbaz = 10;
visit = false;
match (y) {
- baz => visit = true,
- foo => abort(),
- bar => abort(),
+ case baz =>
+ visit = true;
+ case foo =>
+ abort();
+ case bar =>
+ abort();
};
assert(visit);
y = foo;
visit = false;
match (y) {
- baz => abort(),
- foo => visit = true,
- bar => abort(),
+ case baz =>
+ abort();
+ case foo =>
+ visit = true;
+ case bar =>
+ abort();
};
assert(visit);
};
@@ -161,14 +196,16 @@ export fn numeric() void = {
let visit = true;
let x: integer = 1337i;
match (x) {
- s: signed => match (s) {
- i: int => {
- visit = true;
- assert(i == 1337);
- },
- * => abort(),
- },
- u: unsigned => abort(),
+ case s: signed =>
+ match (s) {
+ case i: int =>
+ visit = true;
+ assert(i == 1337);
+ case =>
+ abort();
+ };
+ case u: unsigned =>
+ abort();
};
assert(visit);
};
@@ -179,8 +216,10 @@ type align_8 = (void | int | i64);
export fn alignment_conversion() void = {
let x: align_8 = 1234i;
match (x) {
- y: align_4 => assert(y as int == 1234),
- * => abort(),
+ case y: align_4 =>
+ assert(y as int == 1234);
+ case =>
+ abort();
};
let y: align_4 = 4321i;
x = y: align_8;
@@ -190,7 +229,7 @@ export fn alignment_conversion() void = {
export fn main() void = {
tagged();
termination();
- default();
+ _default();
pointer();
alias();
tagged_result();
diff --git a/tests/26-gen.ha b/tests/26-gen.ha
@@ -23,23 +23,29 @@ export fn main() void = {
let x: (void | int) = 10;
match (x) {
- i: int => assert(i == 10),
- void => abort(),
+ case i: int =>
+ assert(i == 10);
+ case void =>
+ abort();
};
let p = 0;
let p = &p: uintptr: u64: (u64 | void);
let p = match (p) {
- void => abort(),
- p: u64 => p: uintptr: *int,
+ case void =>
+ abort();
+ case p: u64 =>
+ yield p: uintptr: *int;
};
assert(*p == 0);
let thing: int = 0;
let thing = &thing: (*int | int);
let p = match (thing) {
- int => abort(),
- p: *int => p,
+ case int =>
+ abort();
+ case p: *int =>
+ yield p;
};
*p = 0;
};
diff --git a/tests/30-reduction.c b/tests/30-reduction.c
@@ -78,17 +78,25 @@ int main(void) {
"else if (true) null");
test(&ctx, "(nullable *int | void)",
"match (0u8: (u8 | u16 | u32 | u64)) { "
- "u8 => null: *int, "
- "u16 => null: nullable *int, "
- "u32 => null, "
- "u64 => void, "
+ "case u8 => "
+ " yield null: *int; "
+ "case u16 => "
+ " yield null: nullable *int; "
+ "case u32 => "
+ " yield null; "
+ "case u64 => "
+ " yield;"
"}");
test(&ctx, "(nullable *int | void)",
"switch (0) { "
- "42 => null: *int, "
- "69 => null: nullable *int, "
- "1337 => null, "
- "* => void, "
+ "case 42 => "
+ " yield null: *int;"
+ "case 69 => "
+ " yield null: nullable *int;"
+ "case 1337 => "
+ " yield null;"
+ "case => "
+ " yield;"
"};");
// if, match, and switch all use the same code for reduction, so we