How CaseIterable Works Internally in Swift
CaseIterable is one of my favorite features in Swift 4.2. Despite being a simple protocol, it solves a common problem (one I personally faced many times): needing access to an array containing all the cases of a given enum.
If we take a look at how CaseIterable is implemented in the Standard Library, we can see that the protocol is just what one would expect it to be: a simple definition of a collection of cases.
public protocol CaseIterable {
/// A type that can represent a collection of all values of this type.
associatedtype AllCases: Collection where AllCases.Element == Self
/// A collection of all values of this type.
static var allCases: AllCases { get }
}
But this article isn't about the Swift side of this protocol. As you probably know, this protocol is special: you don't need to define and fill the allCases property yourself, because the compiler does it for you.
enum MyEnum: CaseIterable {
case foo
case bar
// Code generated by the compiler
static var allCases: AllCases { // alias for [MyEnum]
return [.foo, .bar]
}
//
}
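As a quick sanity check of the claim above, the sketch below compares a compiler-synthesized conformance with one written out by hand, mirroring the commented code. `ManualEnum` is a hypothetical name used only for the comparison.

```swift
// Conformance synthesized by the compiler: allCases is never written by hand.
enum MyEnum: CaseIterable {
    case foo
    case bar
}

// The same conformance spelled out manually, mirroring the commented code above.
enum ManualEnum: CaseIterable {
    case foo
    case bar

    static var allCases: [ManualEnum] {
        return [.foo, .bar]
    }
}

// For an enum, the synthesized AllCases associated type is a plain array.
let synthesizedIsArray = MyEnum.AllCases.self == [MyEnum].self

print(MyEnum.allCases.count)      // prints "2"
print(ManualEnum.allCases.count)  // prints "2"
print(synthesizedIsArray)         // prints "true"
```

The synthesized collection lists the cases in declaration order, so both versions produce `[.foo, .bar]`.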
This behaviour isn't new: the same concept is applied to many other protocols like RawRepresentable and Codable (and now also Equatable/Hashable), but I had never really researched how it was done. Since I've been studying compilers lately to fix SwiftShield's edge cases, I took this opportunity to jump into Swift's source code, learn something, and show you how it's done.
Retrieving a .swift file's Abstract Syntax Tree
To find out how Swift generates this code, we first need to know what the generated code actually looks like.
You could reverse engineer the resulting binary, but it would be painfully hard to understand what the assembly means. Another option is to fork the Swift compiler and attach lldb to it, but you would need to know where to put a breakpoint in the first place, and I had no idea.
Luckily, the Swift compiler in Xcode's toolchain offers several arguments that let you extract human-readable representations of "processed" Swift source files, and one of these options lets you retrieve the Abstract Syntax Tree (AST) of a file.
Although the AST is just the contents of your file written as a tree-like structure, the AST dumped by the Swift compiler is the type-checked one, so it includes the code the compiler synthesized for you. This allows us to see what an enum with CaseIterable looks like after the compiler has worked its magic.
First, I'll create a basic enum in a file named enum.swift:
enum MyEnum: CaseIterable {
case foo
case bar
}
Now, to get its AST, I'll run swiftc with the -dump-ast argument:
swiftc -dump-ast enum.swift
This returns a gigantic tree due to all the code generation involved in Swift, but here is the part relevant to the declaration of allCases:
(var_decl implicit "allCases" type='[MyEnum]' interface type='[MyEnum]' access=internal type storage_kind=computed
(accessor_decl implicit 'anonname=0x7fa86f015c28' interface type='(MyEnum.Type) -> () -> [MyEnum]' access=internal type getter_for=allCases
(parameter_list
(parameter "self" interface type='MyEnum.Type'))
(parameter_list)
(brace_stmt
(return_stmt implicit
(array_expr type='[MyEnum]'
(dot_syntax_call_expr implicit type='MyEnum' nothrow
(declref_expr implicit type='(MyEnum.Type) -> MyEnum' decl=moduletest.(file).MyEnum.foo@/Users/bruno.rocha/Desktop/moduletest.swift:2:10 function_ref=double)
(type_expr implicit type='MyEnum.Type' typerepr='<<NULL>>'))
(dot_syntax_call_expr implicit type='MyEnum' nothrow
(declref_expr implicit type='(MyEnum.Type) -> MyEnum' decl=moduletest.(file).MyEnum.bar@/Users/bruno.rocha/Desktop/moduletest.swift:3:10 function_ref=double)
(type_expr implicit type='MyEnum.Type' typerepr='<<NULL>>')))))))
ASTs are very verbose, but the node names help us understand what this means. We have a declaration of an allCases property (var_decl) whose getter body (brace_stmt) returns (return_stmt) a MyEnum array (array_expr) containing .foo (a dot_syntax_call_expr of implicit type MyEnum, built from a declref_expr referencing MyEnum.foo) and .bar (same as before).
Verbosity aside, this is the same as the return [.foo, .bar] shown above. But where is this code being injected?
Debugging the Swift Compiler
Since CaseIterable is a relatively simple protocol, we can likely uncover its internals by searching the open source Swift repository on GitHub. I did this and got only about two pages of results, most of them unit tests.
One of the results points to the actual thing: a suspicious method named deriveCaseIterable_enum_getter in a file named DerivedConformanceCaseIterable.cpp, which builds the body of a property. Bingo!
But before analyzing what this method does, I'm interested in knowing how the compiler got here in the first place.
By forking the Swift compiler and building it in debug mode, we can attach lldb to it, set a breakpoint on this method, and call bt to print its backtrace.
lldb -- /swift-fork/build/Ninja-ReleaseAssert+swift-DebugAssert/swift-macosx-x86_64/bin/swiftc -dump-ast enum.swift
(Note) Since I researched this all by myself, some assumptions might not be fully correct. Feel free to correct me if you know the Swift compiler!
If you take a look at the file in question, you'll find that deriveCaseIterable_enum_getter isn't called directly. Instead, another method called deriveCaseIterable passes it along as a reference. This means that a backtrace taken inside it won't reveal the information we want, so I'll backtrace deriveCaseIterable itself instead.
b DerivedConformanceCaseIterable.cpp:82
run
Process 15104 stopped
frame #0: 0x0000000101ca809e DerivedConformance::deriveCaseIterable(this=0x00007ffeefbf6d60, requirement=0x000000010d97d838) at DerivedConformanceCaseIterable.cpp:82
bt
The backtrace goes a long way, but if we take the top seven frames, we end up with:
DerivedConformance::deriveCaseIterable(this=0x00007ffeefbf6d60, requirement=0x000000010d97d838) at DerivedConformanceCaseIterable.cpp:82
TypeChecker::deriveProtocolRequirement(this=0x00007ffeefbf95f0, DC=0x000000010c8d9730, TypeDecl=0x000000010c8d9718, Requirement=0x000000010d97d838) at TypeCheckProtocol.cpp:5137
ConformanceChecker::resolveWitnessViaDerivation(this=0x00007ffeefbf82d0, requirement=0x000000010d97d838) at TypeCheckProtocol.cpp:3081
ConformanceChecker::checkConformance(this=0x00007ffeefbf82d0, Kind=ErrorFixIt) at TypeCheckProtocol.cpp:3665
MultiConformanceChecker::checkIndividualConformance(this=0x00007ffeefbf8058, conformance=0x000000010c8db5e8, issueFixit=true) at TypeCheckProtocol.cpp:1707
MultiConformanceChecker::checkAllConformances(this=0x00007ffeefbf8058) at TypeCheckProtocol.cpp:1328
TypeChecker::checkConformancesInContext(this=0x00007ffeefbf95f0, dc=0x000000010c8d9730, idc=0x000000010c8d9790) at TypeCheckProtocol.cpp:4720
After a quick look at each of these symbols' files, we can see that once the file's structure has been parsed, the compiler runs a few passes to determine whether every protocol conformance and its conditions are satisfied (take a look at the files from the backtrace to see them yourself!).
At checkConformancesInContext, the compiler has access to a context (our enum's declaration). It extracts an array of conformances from it (here, MyEnum's conformance to CaseIterable) and calls checkAllConformances.
checkAllConformances loops over the array of conformances and calls checkIndividualConformance for each of them. If the requirements are not satisfied, compilation warnings/errors are emitted.
checkIndividualConformance seems to perform superficial checks on the conformance. For example, it checks whether a class protocol is being used outside a class, or whether an Objective-C object is trying to conform to a Swift protocol. If the compiler still can't confirm that the requirements are met (because we're literally missing an entire property), checkConformance is called.
checkConformance attempts to validate a conformance through a few procedures. This is where my subpar knowledge of compilers leaves me hanging, but I was able to grasp the meaning of the procedure that matters to us: resolveWitnessViaDerivation. This is where the compiler tries to satisfy a requirement by injecting the missing code.
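The chain of calls above can be summarized with a toy model. Everything in it (`Requirement`, `Conformance`, `ConformanceChecker`, `CheckResult`) is hypothetical. It only mirrors the roles of the C++ functions as described here and is not the compiler's API.

```swift
// Toy model only: none of these types exist in the Swift compiler.
struct Requirement {
    let name: String
}

struct Conformance {
    let protocolName: String
    let requirements: [Requirement]
}

enum CheckResult: Equatable {
    case satisfied(String)  // the user wrote a witness
    case derived(String)    // the compiler synthesized a witness
    case error(String)      // nothing satisfies the requirement
}

struct ConformanceChecker {
    let userWitnesses: Set<String>  // members the type declares itself
    let derivable: Set<String>      // requirements the compiler knows how to synthesize

    // Stands in for resolveWitnessViaDerivation.
    func resolveViaDerivation(_ req: Requirement) -> CheckResult {
        return derivable.contains(req.name) ? .derived(req.name) : .error("missing \(req.name)")
    }

    // Stands in for checkConformance: try the user's code first, then derivation.
    func check(_ conformance: Conformance) -> [CheckResult] {
        return conformance.requirements.map { (req: Requirement) -> CheckResult in
            userWitnesses.contains(req.name) ? .satisfied(req.name) : resolveViaDerivation(req)
        }
    }

    // Stands in for checkAllConformances: check each conformance in turn.
    func checkAll(_ conformances: [Conformance]) -> [CheckResult] {
        return conformances.flatMap { check($0) }
    }
}

let caseIterable = Conformance(protocolName: "CaseIterable",
                               requirements: [Requirement(name: "allCases")])
let checker = ConformanceChecker(userWitnesses: [], derivable: ["allCases"])
print(checker.checkAll([caseIterable]))
```

For MyEnum, nothing is written by hand, so the only way to satisfy allCases is through derivation.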
Deriving Protocols
But before resolveWitnessViaDerivation is called, two important methods that are not in the backtrace are called: getDerivableRequirement and derivesProtocolRequirement. You can see them here.
getDerivableRequirement determines whether a certain requirement supports this kind of code generation in the first place. If the name of the requirement matches a requirement in a known protocol, we proceed:
// CaseIterable.allCases
if (name.isSimpleName(ctx.Id_allCases))
return getRequirement(KnownProtocolKind::CaseIterable);
The getRequirement from the return statement then calls derivesProtocolRequirement, which will try to match the requirement with the protocol's own set of rules.
For the "CaseIterable inside enums" feature, the rules are:
case KnownProtocolKind::CaseIterable:
return !enumDecl->hasPotentiallyUnavailableCaseValue()
&& enumDecl->hasOnlyCasesWithoutAssociatedValues();
To be honest, I'm not really sure what a PotentiallyUnavailableCaseValue refers to (Update: Łukasz Grzywacz discovered that this checks for cases restricted by @available attributes!), but the second condition is something we know: the derivation only works if your cases don't have associated values, as the compiler can't possibly know which values you want there. That's not the case for MyEnum, so we're good!
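When an enum does have associated values, the compiler won't derive allCases. You can still conform by writing the witness yourself and choosing which payloads to include. This is a minimal sketch, and `Shape` and its cases are made up for illustration.

```swift
// The compiler won't derive allCases here, because .circle carries a payload.
enum Shape: CaseIterable, Equatable {
    case point
    case circle(radius: Double)

    // A hand-written witness: we decide which payload values count as "all cases".
    static var allCases: [Shape] {
        return [.point, .circle(radius: 1.0)]
    }
}

print(Shape.allCases.count)  // prints "2"
```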
With derivation being possible, we head back to the backtrace, where deriveProtocolRequirement gets called. The compiler will now attempt to generate the missing code.
This method performs the same protocol name matching, but this time to actually perform the code generation. For CaseIterable, this results in deriveCaseIterable being called.
case KnownProtocolKind::CaseIterable:
return derived.deriveCaseIterable(Requirement);
deriveCaseIterable performs a few more checks, such as whether the protocol was adopted in an extension (which is a no-no for derivation). If all goes well, it declares an empty allCases property and finally registers the method that fills it as the getter's body synthesizer: the deriveCaseIterable_enum_getter we first saw.
auto *returnTy = computeAllCasesType(Nominal); // [MyEnum]
VarDecl *propDecl;
PatternBindingDecl *pbDecl;
std::tie(propDecl, pbDecl) = declareDerivedProperty(TC.Context.Id_allCases, returnTy, returnTy, /*isStatic=*/true, /*isFinal=*/true);
// Define the getter.
auto *getterDecl = addGetterToReadOnlyDerivedProperty(TC, propDecl, returnTy);
// Set the getter's body.
getterDecl->setBodySynthesizer(&deriveCaseIterable_enum_getter);
This is the definition of deriveCaseIterable_enum_getter:
void deriveCaseIterable_enum_getter(AbstractFunctionDecl *funcDecl) {
auto *parentDC = funcDecl->getDeclContext();
auto *parentEnum = parentDC->getSelfEnumDecl();
auto enumTy = parentDC->getDeclaredTypeInContext();
auto &C = parentDC->getASTContext();
SmallVector<Expr *, 8> elExprs;
for (EnumElementDecl *elt : parentEnum->getAllElements()) {
auto *ref = new (C) DeclRefExpr(elt, DeclNameLoc(), /*implicit*/true);
auto *base = TypeExpr::createImplicit(enumTy, C);
auto *apply = new (C) DotSyntaxCallExpr(ref, SourceLoc(), base);
elExprs.push_back(apply);
}
auto *arrayExpr = ArrayExpr::create(C, SourceLoc(), elExprs, {}, SourceLoc());
auto *returnStmt = new (C) ReturnStmt(SourceLoc(), arrayExpr);
auto *body = BraceStmt::create(C, SourceLoc(), ASTNode(returnStmt), SourceLoc());
funcDecl->setBody(body);
}
The interesting thing about this method is that it's a lot less complicated than a non-compiler person like me would expect. Because we're in the middle of compilation, the compiler has access to a mutable version of the AST we saw above and a direct reference to the node that represents the declaration of our CaseIterable-semi-conformant enum. To add allCases to it, we literally write it in AST form and attach it to the enum's node.
Although C++ isn't the easiest language to read, you can see that this just iterates over the enum's cases and builds a return statement out of expressions that match the AST we saw above. The parameter funcDecl is the getter of allCases, still without a body, which was created by deriveCaseIterable. Once the expression is built, it's set as that getter's body.
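To make the "write it in AST form" idea concrete, here is a toy Swift version of the same loop. The node types are hypothetical stand-ins named after the `-dump-ast` output, not the compiler's classes. The printer only loosely imitates the real dump format.

```swift
// Toy AST node types, named after the -dump-ast output. Hypothetical only.
indirect enum Expr {
    case declRef(String)              // declref_expr
    case typeExpr(String)             // type_expr
    case dotSyntaxCall(Expr, Expr)    // dot_syntax_call_expr (reference, base type)
    case array([Expr])                // array_expr
}

enum Stmt {
    case returnStmt(Expr)             // return_stmt
    case brace([Stmt])                // brace_stmt
}

// Mirrors the loop in deriveCaseIterable_enum_getter: one EnumName.caseName
// call expression per case, collected into an array and returned.
func makeAllCasesBody(enumName: String, cases: [String]) -> Stmt {
    let elements = cases.map { name in
        Expr.dotSyntaxCall(.declRef(enumName + "." + name), .typeExpr(enumName))
    }
    return .brace([.returnStmt(.array(elements))])
}

// Formats one node as "(label", its indented children, then ")".
func node(_ label: String, _ children: [String], indent: String) -> String {
    let body = children.isEmpty ? "" : "\n" + children.joined(separator: "\n")
    return indent + "(" + label + body + ")"
}

func render(_ expr: Expr, indent: String) -> String {
    let child = indent + "  "
    switch expr {
    case .declRef(let name):
        return node("declref_expr " + name, [], indent: indent)
    case .typeExpr(let name):
        return node("type_expr " + name + ".Type", [], indent: indent)
    case .dotSyntaxCall(let ref, let base):
        return node("dot_syntax_call_expr",
                    [render(ref, indent: child), render(base, indent: child)],
                    indent: indent)
    case .array(let elements):
        return node("array_expr", elements.map { render($0, indent: child) }, indent: indent)
    }
}

func render(_ stmt: Stmt, indent: String = "") -> String {
    let child = indent + "  "
    switch stmt {
    case .returnStmt(let expr):
        return node("return_stmt", [render(expr, indent: child)], indent: indent)
    case .brace(let stmts):
        return node("brace_stmt", stmts.map { render($0, indent: child) }, indent: indent)
    }
}

let output = render(makeAllCasesBody(enumName: "MyEnum", cases: ["foo", "bar"]))
print(output)
```

The printed tree has the same brace_stmt / return_stmt / array_expr / dot_syntax_call_expr nesting as the real dump, without the type and location annotations.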
Fun time: Adding more properties to CaseIterable
Now that we've figured out how it works, how about adding our own properties to it? I think my fake CaseIterable would benefit from a first property that returns the first declared case.
From the Standard Library's point of view, this is pretty straightforward, as we just need to define a new static var:
public protocol CaseIterable {
/// A type that can represent a collection of all values of this type.
associatedtype AllCases: Collection where AllCases.Element == Self
/// A collection of all values of this type.
static var allCases: AllCases { get }
/// The first case of this type.
static var first: Self { get }
}
But users of this protocol shouldn't need to implement the first property when it's used on an enum, so I want this property to be derived by the compiler as well.
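For comparison, you can get a similar convenience today without touching the compiler, through a protocol extension built on the synthesized allCases. This sketch is not what this article implements. `firstCase` is a hypothetical name chosen to avoid clashing with the `first` requirement. It returns an optional so that an enum with no cases yields nil.

```swift
extension CaseIterable {
    /// The first declared case, or nil when the type has no cases.
    /// (Hypothetical helper; synthesized allCases lists cases in declaration order.)
    static var firstCase: Self? {
        return allCases.first
    }
}

enum MyEnum: CaseIterable {
    case foo
    case bar
}

enum EmptyEnum: CaseIterable {}

if let first = MyEnum.firstCase {
    print(first)  // prints "foo"
}
print(EmptyEnum.firstCase == nil)  // prints "true"
```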
To have the compiler derive it, I'll start by cloning the deriveCaseIterable_enum_getter method that generates the case array and modifying it so the expression returns the first case instead of the array:
void deriveCaseIterable_first(AbstractFunctionDecl *funcDecl) {
auto *parentDC = funcDecl->getDeclContext();
auto *parentEnum = parentDC->getSelfEnumDecl();
auto enumTy = parentDC->getDeclaredTypeInContext();
auto &C = parentDC->getASTContext();
EnumElementDecl *elt = parentEnum->getAllElements().front();
auto *ref = new (C) DeclRefExpr(elt, DeclNameLoc(), /*implicit*/true);
auto *base = TypeExpr::createImplicit(enumTy, C);
auto *dotExpr = new (C) DotSyntaxCallExpr(ref, SourceLoc(), base);
auto *returnStmt = new (C) ReturnStmt(SourceLoc(), dotExpr);
auto *body = BraceStmt::create(C, SourceLoc(), ASTNode(returnStmt), SourceLoc());
funcDecl->setBody(body);
}
With that done, we now need to get this method called. We saw previously that deriveCaseIterable_enum_getter is registered by deriveCaseIterable(). If we inspect the contents of that method, we'll find that it checks the name of the requirement being derived:
ValueDecl *DerivedConformance::deriveCaseIterable(ValueDecl *requirement) {
// Deleted to make stuff shorter: Some pre-checks
if (requirement->getBaseName() != TC.Context.Id_allCases) {
// Deleted to make stuff shorter: Throw compilation error
}
auto *returnTy = computeAllCasesType(Nominal); // Define the [MyEnum] return type
// Deleted to make stuff shorter: Define allCases's getter
declareDerivedProperty(TC.Context.Id_allCases, returnTy, returnTy, /*isStatic=*/true, /*isFinal=*/true);
// Deleted to make stuff shorter: Prepare allCases's getter
getterDecl->setBodySynthesizer(&deriveCaseIterable_enum_getter);
}
After searching a bit, I found that the Id_allCases identifier comes from a file named KnownIdentifiers.def. I edited it to add a new Id_first identifier for our feature. I also added Id_first to the getDerivableRequirement() method mentioned above so the compiler knows that this property can be derived.
For this feature to work, we need to keep the old allCases logic but add an else block to handle the new first requirement.
After creating a block for first, we need to change returnTy to be MyEnum instead of [MyEnum]. We also need declareDerivedProperty() to use Id_first as the property name instead of Id_allCases, and finally setBodySynthesizer needs to use the new method.
To make returnTy be MyEnum, I looked up how computeAllCasesType() retrieves the enum's type, which turned out to be a call to Nominal->getDeclaredInterfaceType().
After some coding, the final method looks like this (you can see the full version here):
ValueDecl *DerivedConformance::deriveCaseIterable(ValueDecl *requirement) {
// Deleted to make stuff shorter: Some pre-checks
Type returnTy;
Identifier propertyId;
if (requirement->getBaseName() == TC.Context.Id_allCases) {
returnTy = computeAllCasesType(Nominal);
propertyId = TC.Context.Id_allCases;
} else if (requirement->getBaseName() == TC.Context.Id_first) {
returnTy = Nominal->getDeclaredInterfaceType();
propertyId = TC.Context.Id_first;
} else {
// Deleted to make stuff shorter: Throw compilation error
}
// Deleted to make stuff shorter: Define allCases's getter
declareDerivedProperty(propertyId, returnTy, returnTy, /*isStatic=*/true, /*isFinal=*/true);
if (requirement->getBaseName() == TC.Context.Id_allCases) {
getterDecl->setBodySynthesizer(&deriveCaseIterable_enum_getter);
} else {
getterDecl->setBodySynthesizer(&deriveCaseIterable_first);
}
}
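The branching in the final method can be modeled in plain Swift, with strings standing in for the compiler's types and AST nodes. `DerivedProperty` and this `deriveCaseIterable` function are hypothetical. Unlike the compiler patch, the sketch returns nil for an empty enum instead of crashing.

```swift
// Strings stand in for the compiler's Type and AST nodes; hypothetical model only.
struct DerivedProperty: Equatable {
    let name: String
    let type: String
    let body: String
}

// Mirrors the branching in the final deriveCaseIterable: the requirement's
// name selects the property type and the body synthesizer.
func deriveCaseIterable(requirement: String, enumName: String, cases: [String]) -> DerivedProperty? {
    switch requirement {
    case "allCases":
        // Like deriveCaseIterable_enum_getter: every case, in order.
        let list = cases.map { "." + $0 }.joined(separator: ", ")
        return DerivedProperty(name: "allCases", type: "[" + enumName + "]", body: "return [" + list + "]")
    case "first":
        // Like deriveCaseIterable_first, but returns nil for an empty enum
        // instead of reading a first element that doesn't exist.
        guard let firstCase = cases.first else { return nil }
        return DerivedProperty(name: "first", type: enumName, body: "return ." + firstCase)
    default:
        return nil  // the compiler would report an error here instead
    }
}

if let first = deriveCaseIterable(requirement: "first", enumName: "MyEnum", cases: ["foo", "bar"]) {
    print(first.type + ": " + first.body)  // prints "MyEnum: return .foo"
}
```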
After building the compiler, we can get a CaseIterable enum's first case without explicitly defining it!
enum MyEnum: CaseIterable {
case foo
case bar
}
print(MyEnum.first) // prints "foo"
Conclusion
Compilers are scary, and Swift's is no different. I'm still trying to figure out how most things work (if you're a compiler expert, I'm looking for tips on great books and resources!), but as I've said before in my posts, knowing the internals of a language can really help you write efficient code. I had a lot of fun inspecting this feature and hope it was useful to you in some way.
Follow me on Twitter (@rockbruno_) and let me know of any suggestions and corrections you want to share.
References and Good reads
The Swift Source Code
Update: People were curious about how first behaves if the enum has no cases: it crashes! We can fix it by adding a new rule to derivesProtocolRequirement that returns false if the requirement is first and the enum is empty. That would make the compiler emit a "does not conform to CaseIterable" error in that case.
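The derivation rules quoted earlier, together with the extra rule proposed in the update, fit in a small predicate. `EnumCaseInfo` and `canDerive` are hypothetical. The real checks are `hasPotentiallyUnavailableCaseValue()` and `hasOnlyCasesWithoutAssociatedValues()` on the enum declaration.

```swift
// Hypothetical summary of one enum case, as the checks above need it.
struct EnumCaseInfo {
    let name: String
    let hasAssociatedValues: Bool
    let isPotentiallyUnavailable: Bool  // e.g. restricted by an @available attribute
}

// Mirrors the CaseIterable rules in derivesProtocolRequirement, plus the
// proposed rule that `first` can't be derived for an enum with no cases.
func canDerive(requirement: String, cases: [EnumCaseInfo]) -> Bool {
    if requirement == "first" && cases.isEmpty {
        return false  // nothing to return: report "does not conform" instead
    }
    let noUnavailableCases = !cases.contains { $0.isPotentiallyUnavailable }
    let noAssociatedValues = cases.allSatisfy { !$0.hasAssociatedValues }
    return noUnavailableCases && noAssociatedValues
}

let foo = EnumCaseInfo(name: "foo", hasAssociatedValues: false, isPotentiallyUnavailable: false)
print(canDerive(requirement: "first", cases: [foo]))  // prints "true"
print(canDerive(requirement: "first", cases: []))     // prints "false"
```

allCases stays derivable for an empty enum, because an empty array is a valid answer. first has no case to return.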