How CaseIterable Works Internally in Swift

CaseIterable is one of my favorite features in Swift 4.2. Despite being a simple protocol, it solves a common problem that I've personally faced many times: needing access to a collection containing all the cases of a certain enum.

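If we take a look at how CaseIterable is implemented in the Standard Library, we can see that the protocol is just what one would expect it to be: a simple requirement for a collection of all cases. Note that AllCases only has to be a Collection whose elements are the conforming type; it doesn't have to be an array.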

public protocol CaseIterable {
    /// A type that can represent a collection of all values of this type.
    associatedtype AllCases: Collection where AllCases.Element == Self

    /// A collection of all values of this type.
    static var allCases: AllCases { get }
}
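Because AllCases is only constrained to Collection, a hand-written conformance isn't limited to arrays. Here's a small sketch (Direction is a made-up example) where the compiler infers AllCases from the declared type of allCases instead of synthesizing anything:

```swift
// Hypothetical example: a manual conformance whose AllCases is a
// ContiguousArray instead of the Array the compiler would synthesize.
enum Direction: CaseIterable {
    case north, south, east, west

    // Because allCases is written by hand, nothing is derived here.
    // AllCases is inferred as ContiguousArray<Direction> from this declaration.
    static var allCases: ContiguousArray<Direction> {
        return [.north, .south, .east, .west]
    }
}

let names = Direction.allCases.map { "\($0)" }
print(names) // ["north", "south", "east", "west"]
```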

But this article isn't about the Swift side of this protocol. As you probably know, this protocol is special: when an enum adopts it, you don't need to define and fill the allCases property yourself, because the compiler does it for you.

enum MyEnum: CaseIterable {
    case foo
    case bar

    // Code generated by the compiler:
    static var allCases: AllCases { // AllCases is an alias for [MyEnum]
        return [.foo, .bar]
    }
}
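You can check from plain Swift that the synthesized property behaves like the hand-written version: the enum below declares no allCases of its own, yet it type-checks as an array and lists the cases in declaration order.

```swift
// No allCases is written here: the conformance is entirely compiler-synthesized.
enum MyEnum: CaseIterable {
    case foo
    case bar
}

// This only compiles because the synthesized AllCases is [MyEnum].
let cases: [MyEnum] = MyEnum.allCases

print(cases.count)                 // 2
print(cases == [.foo, .bar])       // true
```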

This behaviour isn't new: the same concept is applied to other protocols like RawRepresentable and Codable (and, since Swift 4.1, Equatable and Hashable), but I had never really researched how it was done. Since I've been studying compilers lately in order to fix SwiftShield's edge cases, I took this opportunity to dive into Swift's source code, learn something, and show you how it's done.
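For reference, here's what a few of those other derived conformances look like from the user's side. Planet and Point are made-up types; every member used below is generated by the compiler rather than written by hand:

```swift
// rawValue and init?(rawValue:) are derived from the String raw type;
// allCases is derived from CaseIterable.
enum Planet: String, CaseIterable {
    case mercury, venus, earth
}

// == and hash(into:) are synthesized member-wise.
struct Point: Hashable {
    var x: Int
    var y: Int
}

print(Planet.earth.rawValue)                              // earth
print(Planet(rawValue: "venus") == Planet.venus)          // true
print(Planet(rawValue: "pluto") == nil)                   // true
print(Point(x: 1, y: 2) == Point(x: 1, y: 2))             // true
print(Set([Point(x: 0, y: 0), Point(x: 0, y: 0)]).count)  // 1
```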

Retrieving a .swift file's Abstract Syntax Tree

In order to find out how the compiler generates this code, we first need to know what the generated code actually looks like.

You could reverse engineer the resulting binary, but it would be painfully hard to understand what the assembly means. Another option is to fork the Swift compiler and attach lldb to it, but you'd need to know where to set a breakpoint in the first place, which I didn't.

Luckily, the Swift compiler in your Xcode toolchain offers several arguments that output human-readable, "processed" versions of Swift source files, and one of these options lets you retrieve the Abstract Syntax Tree (AST) of a file.

Although the AST is just the contents of your file represented as a tree-like structure, the AST dumped by the Swift compiler is the type-checked one, so it also contains the implicit declarations the compiler synthesized for your file. This allows us to see what an enum with CaseIterable looks like after type checking.

First, I'll create a basic enum in a file named enum.swift:

enum MyEnum: CaseIterable {
    case foo
    case bar
}

Now, to get the AST of that, I'll run swiftc with the -dump-ast argument:

swiftc -dump-ast enum.swift

This returns a gigantic tree structure due to all the code generation involved in Swift, but I extracted the part relevant to the declaration of allCases:

(var_decl implicit "allCases" type='[MyEnum]' interface type='[MyEnum]' access=internal type storage_kind=computed
  (accessor_decl implicit 'anonname=0x7fa86f015c28' interface type='(MyEnum.Type) -> () -> [MyEnum]' access=internal type getter_for=allCases
    (parameter_list
      (parameter "self" interface type='MyEnum.Type'))
    (parameter_list)
    (brace_stmt
      (return_stmt implicit
        (array_expr type='[MyEnum]'
          (dot_syntax_call_expr implicit type='MyEnum' nothrow
            (declref_expr implicit type='(MyEnum.Type) -> MyEnum' decl=moduletest.(file).MyEnum.foo@/Users/bruno.rocha/Desktop/moduletest.swift:2:10 function_ref=double)
            (type_expr implicit type='MyEnum.Type' typerepr='<<NULL>>'))
          (dot_syntax_call_expr implicit type='MyEnum' nothrow
            (declref_expr implicit type='(MyEnum.Type) -> MyEnum' decl=moduletest.(file).MyEnum.bar@/Users/bruno.rocha/Desktop/moduletest.swift:3:10 function_ref=double)
            (type_expr implicit type='MyEnum.Type' typerepr='<<NULL>>')))))))

ASTs are very verbose, but the node names help us understand what this means. We have a declaration of an allCases property (var_decl) whose getter (accessor_decl) has a body (brace_stmt) that returns (return_stmt) a MyEnum array (array_expr) containing .foo (a dot_syntax_call_expr of implicit type MyEnum, followed by a declref_expr referencing MyEnum.foo) and .bar (built the same way).

Verbosity aside, this is the same as the return [.foo, .bar] shown above. But where is this code injection happening?

Debugging the Swift Compiler

Since CaseIterable is a relatively simple protocol, we can likely uncover its internals by searching the open-source Swift repository on GitHub. I did this and got only about two pages of results, most of them unit tests.

One of the results is a reference to the real thing: a suspicious method named deriveCaseIterable_enum_getter, in a file named DerivedConformanceCaseIterable.cpp, that takes a property's getter and fills in its body. Bingo!

But before analyzing what this method does, I'm interested in knowing how the compiler got here in the first place.

By forking the Swift compiler and building it in debug mode, we can attach lldb to it, set a breakpoint on this method, and call bt to print its backtrace.

lldb -- /swift-fork/build/Ninja-ReleaseAssert+swift-DebugAssert/swift-macosx-x86_64/bin/swiftc -dump-ast enum.swift

(Note) Since I researched this all by myself, some assumptions might not be fully correct. Feel free to correct me if you know the Swift compiler!

If you take a look at the file in question, you'll find that deriveCaseIterable_enum_getter isn't called directly. Instead, another method called deriveCaseIterable passes it along as a function pointer, to be run later. This means that a backtrace taken inside deriveCaseIterable_enum_getter won't reveal the information we want, so instead I'll set the breakpoint in deriveCaseIterable itself.

b DerivedConformanceCaseIterable.cpp:82
run
Process 15104 stopped
frame #0: 0x0000000101ca809e DerivedConformance::deriveCaseIterable(this=0x00007ffeefbf6d60, requirement=0x000000010d97d838) at DerivedConformanceCaseIterable.cpp:82
bt

The backtrace is long, but if we take the seven frames closest to the breakpoint, we end up with:

DerivedConformance::deriveCaseIterable(this=0x00007ffeefbf6d60, requirement=0x000000010d97d838) at DerivedConformanceCaseIterable.cpp:82
TypeChecker::deriveProtocolRequirement(this=0x00007ffeefbf95f0, DC=0x000000010c8d9730, TypeDecl=0x000000010c8d9718, Requirement=0x000000010d97d838) at TypeCheckProtocol.cpp:5137
ConformanceChecker::resolveWitnessViaDerivation(this=0x00007ffeefbf82d0, requirement=0x000000010d97d838) at TypeCheckProtocol.cpp:3081
ConformanceChecker::checkConformance(this=0x00007ffeefbf82d0, Kind=ErrorFixIt) at TypeCheckProtocol.cpp:3665
MultiConformanceChecker::checkIndividualConformance(this=0x00007ffeefbf8058, conformance=0x000000010c8db5e8, issueFixit=true) at TypeCheckProtocol.cpp:1707
MultiConformanceChecker::checkAllConformances(this=0x00007ffeefbf8058) at TypeCheckProtocol.cpp:1328
TypeChecker::checkConformancesInContext(this=0x00007ffeefbf95f0, dc=0x000000010c8d9730, idc=0x000000010c8d9790) at TypeCheckProtocol.cpp:4720

After a quick look at the files behind each of these frames, we can see that after parsing the file's structure, the compiler runs a few passes to check whether every protocol conformance is satisfied (take a look at the files from the backtrace to see them yourself!).

At checkConformancesInContext, the compiler has access to a context (our enum's declaration). It extracts an array of conformances from it (here, just the conformance to CaseIterable) and calls checkAllConformances.

checkAllConformances loops over the array of conformances and calls checkIndividualConformance for each of them. If a conformance's requirements are not satisfied, compilation warnings or errors are emitted.

checkIndividualConformance seems to perform some superficial checks on the conformance, such as whether a class-only protocol is being adopted by something that isn't a class, or whether an Objective-C type is trying to conform to a Swift protocol. If the compiler still can't confirm the requirements (because, in our case, we're literally missing an entire property), checkConformance is called.

checkConformance attempts to satisfy the protocol's requirements through a few procedures. This is where my subpar knowledge of compilers leaves me hanging, but I was able to grasp the one that matters to us: resolveWitnessViaDerivation. This is where the compiler tries to satisfy a requirement by generating the missing code.
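To summarize the flow, here's a heavily simplified, hypothetical Swift model of it. None of these types or functions exist in the compiler; the names only echo the C++ methods above, and the real checks are far more involved. The idea it shows is the order of attempts: use a member the user declared, otherwise try to derive one, otherwise report the requirement as missing.

```swift
// Hypothetical model of conformance checking; not the compiler's real API.
enum WitnessResult {
    case explicit   // the user declared the member
    case derived    // the compiler synthesized it
    case missing    // would turn into a "does not conform" error
}

struct Conformance {
    let protocolName: String
    let requirements: [String]
}

struct EnumDecl {
    var members: Set<String>        // members written (or synthesized) so far
    let hasAssociatedValues: Bool
}

// Stand-in for resolveWitnessViaDerivation: only some requirements are derivable.
func resolveWitnessViaDerivation(_ requirement: String, of conformance: Conformance,
                                 in decl: inout EnumDecl) -> Bool {
    guard conformance.protocolName == "CaseIterable",
          requirement == "allCases",
          !decl.hasAssociatedValues else { return false }
    decl.members.insert(requirement)  // "inject" the missing member
    return true
}

// Stand-in for checkConformance: explicit witness first, then derivation.
func checkConformance(_ conformance: Conformance, in decl: inout EnumDecl) -> [String: WitnessResult] {
    var results: [String: WitnessResult] = [:]
    for requirement in conformance.requirements {
        if decl.members.contains(requirement) {
            results[requirement] = .explicit
        } else if resolveWitnessViaDerivation(requirement, of: conformance, in: &decl) {
            results[requirement] = .derived
        } else {
            results[requirement] = .missing
        }
    }
    return results
}

// Stand-in for checkAllConformances: visit every conformance the type declares.
func checkAllConformances(_ conformances: [Conformance],
                          in decl: inout EnumDecl) -> [String: [String: WitnessResult]] {
    var report: [String: [String: WitnessResult]] = [:]
    for conformance in conformances {
        report[conformance.protocolName] = checkConformance(conformance, in: &decl)
    }
    return report
}

let caseIterable = Conformance(protocolName: "CaseIterable", requirements: ["allCases"])
var myEnum = EnumDecl(members: [], hasAssociatedValues: false)
let report = checkAllConformances([caseIterable], in: &myEnum)
print(report["CaseIterable"]?["allCases"] == WitnessResult.derived) // true
```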

Deriving Protocols

But before resolveWitnessViaDerivation is called, two important methods that are not in the backtrace are called: getDerivableRequirement and derivesProtocolRequirement. You can see them here.

getDerivableRequirement determines whether a given requirement supports this kind of code generation in the first place. If the requirement's name matches a derivable requirement of a known protocol, we proceed:

// CaseIterable.allCases
if (name.isSimpleName(ctx.Id_allCases))
    return getRequirement(KnownProtocolKind::CaseIterable);

The getRequirement call in the return statement then calls derivesProtocolRequirement, which checks the requirement against the protocol's own set of rules.

For the "CaseIterable inside enums" feature, the rules are:

case KnownProtocolKind::CaseIterable:
    return !enumDecl->hasPotentiallyUnavailableCaseValue()
           && enumDecl->hasOnlyCasesWithoutAssociatedValues();

To be honest, I wasn't sure what a PotentiallyUnavailableCaseValue referred to (Update: Łukasz Grzywacz discovered that this checks for cases that are only conditionally available, i.e. cases marked with @available!), but the second condition is something we know: derivation only works if your cases don't have associated values, as the compiler can't possibly know which values you'd want there. MyEnum has no associated values, so we're good!
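To see the associated-value rule from the user's side, here's a made-up enum with a payload case. Because derivation is refused for it, allCases has to be written by hand; deleting the explicit property should make the compiler report that the type does not conform to CaseIterable.

```swift
// `loaded` carries a payload, so the compiler won't derive allCases for this enum.
enum LoadingState: CaseIterable {
    case idle
    case loading
    case loaded(itemCount: Int)

    // Hand-written witness: we choose a representative value for the payload case,
    // which is exactly the decision the compiler can't make on our behalf.
    static var allCases: [LoadingState] {
        return [.idle, .loading, .loaded(itemCount: 0)]
    }
}

print(LoadingState.allCases.count) // 3
```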

With the derivation being possible, we head back to the backtrace as deriveProtocolRequirement gets called. The compiler will now attempt to generate the remaining code.

This method performs the same kind of protocol matching, but this time to actually generate the code. For CaseIterable, this results in deriveCaseIterable being called.

case KnownProtocolKind::CaseIterable:
    return derived.deriveCaseIterable(Requirement);

deriveCaseIterable performs a few more checks, such as whether the conformance was declared in an extension (which is a no-no for derivation). If all goes well, it declares an empty allCases property with a getter and registers the method that will fill that getter's body: the deriveCaseIterable_enum_getter that we first saw.

auto *returnTy = computeAllCasesType(Nominal); // [MyEnum]

VarDecl *propDecl;
PatternBindingDecl *pbDecl;
std::tie(propDecl, pbDecl) = declareDerivedProperty(TC.Context.Id_allCases, returnTy, returnTy, /*isStatic=*/true, /*isFinal=*/true);

// Define the getter.
auto *getterDecl = addGetterToReadOnlyDerivedProperty(TC, propDecl, returnTy);

// Set the getter's body.
getterDecl->setBodySynthesizer(&deriveCaseIterable_enum_getter);
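Note that the last line doesn't run deriveCaseIterable_enum_getter: it hands it to setBodySynthesizer, which, as far as I can tell, stores it so the body can be built later on demand. That is also why a breakpoint inside the synthesizer wouldn't show the conformance-checking frames. Here's a hypothetical Swift model of that deferral; the FuncDecl class below is made up and its setBodySynthesizer only mimics the role of the C++ method:

```swift
// Hypothetical model of a lazily synthesized function body (not the compiler's API).
final class FuncDecl {
    let name: String
    private var synthesizer: ((FuncDecl) -> Void)?
    private var storedBody: [String]?

    init(name: String) { self.name = name }

    // Remember *how* to build the body, but don't build it yet.
    func setBodySynthesizer(_ synthesize: @escaping (FuncDecl) -> Void) {
        synthesizer = synthesize
    }

    func setBody(_ statements: [String]) {
        storedBody = statements
    }

    // The first request for the body runs the synthesizer exactly once.
    var body: [String] {
        if storedBody == nil, let synthesize = synthesizer {
            synthesizer = nil
            synthesize(self)
        }
        return storedBody ?? []
    }
}

var synthesizerRuns = 0
let getter = FuncDecl(name: "allCases")
getter.setBodySynthesizer { decl in
    synthesizerRuns += 1
    decl.setBody(["return [.foo, .bar]"])
}

print(synthesizerRuns)  // 0: registering the synthesizer did not run it
print(getter.body)      // ["return [.foo, .bar]"]
_ = getter.body
print(synthesizerRuns)  // 1: later requests reuse the stored body
```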

This is the definition of deriveCaseIterable_enum_getter:

void deriveCaseIterable_enum_getter(AbstractFunctionDecl *funcDecl) {
    auto *parentDC = funcDecl->getDeclContext();
    auto *parentEnum = parentDC->getSelfEnumDecl();
    auto enumTy = parentDC->getDeclaredTypeInContext();
    auto &C = parentDC->getASTContext();

    SmallVector<Expr *, 8> elExprs;
    for (EnumElementDecl *elt : parentEnum->getAllElements()) {
        auto *ref = new (C) DeclRefExpr(elt, DeclNameLoc(), /*implicit*/true);
        auto *base = TypeExpr::createImplicit(enumTy, C);
        auto *apply = new (C) DotSyntaxCallExpr(ref, SourceLoc(), base);
        elExprs.push_back(apply);
    }
    auto *arrayExpr = ArrayExpr::create(C, SourceLoc(), elExprs, {}, SourceLoc());

    auto *returnStmt = new (C) ReturnStmt(SourceLoc(), arrayExpr);
    auto *body = BraceStmt::create(C, SourceLoc(), ASTNode(returnStmt), SourceLoc());
    funcDecl->setBody(body);
}

The interesting thing about this method is that it's a lot less complicated than a non-compiler person like me would expect. Because we're in the middle of compilation, the compiler has a mutable version of the AST we saw above and a direct reference to the node that represents the declaration of our CaseIterable-semi-conformant enum. To add allCases to it, we literally write the code in AST form and attach it to the enum's node.

Although C++ isn't the easiest language to read, you can see that this simply iterates over the enum's cases, builds a .case expression for each one (a DotSyntaxCallExpr wrapping a DeclRefExpr), and places them in an array inside a return statement, matching the AST we saw above. The parameter funcDecl is the empty getter of allCases, which was created by deriveCaseIterable. Once the return statement is built, it's wrapped in a brace statement and set as the getter's body.
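The same construction can be mimicked with a few toy node types. This is a hypothetical sketch: the Expr and Stmt enums and synthesizeAllCasesBody are made up and only loosely named after the nodes in the dump, but the loop follows the same steps as deriveCaseIterable_enum_getter (one reference plus base type per case, an array, a return, a brace statement):

```swift
// Hypothetical, minimal AST node types loosely named after the ones in the dump.
indirect enum Expr {
    case declRef(String)                      // like DeclRefExpr: a reference to a case
    case typeExpr(String)                     // like TypeExpr: the enum type used as a base
    case dotSyntaxCall(fn: Expr, base: Expr)  // like DotSyntaxCallExpr: Base.member
    case array([Expr])                        // like ArrayExpr

    var source: String {
        switch self {
        case .declRef(let name): return name
        case .typeExpr(let typeName): return typeName
        case let .dotSyntaxCall(fn, base): return "\(base.source).\(fn.source)"
        case .array(let elements):
            return "[" + elements.map { $0.source }.joined(separator: ", ") + "]"
        }
    }
}

enum Stmt {
    case returnStmt(Expr)   // like ReturnStmt
    case brace([Stmt])      // like BraceStmt

    var source: String {
        switch self {
        case .returnStmt(let value): return "return \(value.source)"
        case .brace(let statements):
            return "{ " + statements.map { $0.source }.joined(separator: "; ") + " }"
        }
    }
}

// Mirrors the loop in deriveCaseIterable_enum_getter: one dot-syntax call per case.
func synthesizeAllCasesBody(enumName: String, caseNames: [String]) -> Stmt {
    var elements: [Expr] = []
    for name in caseNames {
        let ref = Expr.declRef(name)
        let base = Expr.typeExpr(enumName)
        elements.append(.dotSyntaxCall(fn: ref, base: base))
    }
    return .brace([.returnStmt(.array(elements))])
}

let body = synthesizeAllCasesBody(enumName: "MyEnum", caseNames: ["foo", "bar"])
print(body.source) // { return [MyEnum.foo, MyEnum.bar] }
```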

Fun time: Adding more properties to CaseIterable

Now that we've figured out how it works, how about adding our own properties to it? I think my fake CaseIterable would benefit from a first property that returns the first declared case.

From the Standard Library's point of view, this is pretty straightforward, as we just need to declare a new static var requirement:

public protocol CaseIterable {
    /// A type that can represent a collection of all values of this type.
    associatedtype AllCases: Collection where AllCases.Element == Self

    /// A collection of all values of this type.
    static var allCases: AllCases { get }

    /// The first case of this type.
    static var first: Self { get }
}

But the users of this protocol don't need to fill the first property if it's being used on an enum, so I want this property to be derived by the compiler as well.
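For comparison, without forking the compiler, something similar can be approximated in plain Swift with a protocol extension that reads from allCases. This is only a sketch, not what the modified compiler does: the name firstCase is hypothetical, and it returns an optional so that an enum with no cases yields nil instead of crashing.

```swift
extension CaseIterable {
    // Hypothetical helper: the first element of allCases, or nil if there are none.
    static var firstCase: Self? {
        return allCases.first
    }
}

enum MyEnum: CaseIterable {
    case foo
    case bar
}

enum NoCases: CaseIterable {}

if let first = MyEnum.firstCase {
    print(first)                  // foo
}
print(NoCases.firstCase == nil)   // true
```

The trade-off is the optional return type: the derived version can promise a non-optional Self only because the compiler can inspect the enum's cases at compile time.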

To do this, I'll first clone the deriveCaseIterable_enum_getter method that generates the case array and modify it so the expression returns the first case instead of the array:

void deriveCaseIterable_first(AbstractFunctionDecl *funcDecl) {
    auto *parentDC = funcDecl->getDeclContext();
    auto *parentEnum = parentDC->getSelfEnumDecl();
    auto enumTy = parentDC->getDeclaredTypeInContext();
    auto &C = parentDC->getASTContext();

    EnumElementDecl *elt = parentEnum->getAllElements().front();
    auto *ref = new (C) DeclRefExpr(elt, DeclNameLoc(), /*implicit*/true);
    auto *base = TypeExpr::createImplicit(enumTy, C);
    auto *dotExpr = new (C) DotSyntaxCallExpr(ref, SourceLoc(), base);

    auto *returnStmt = new (C) ReturnStmt(SourceLoc(), dotExpr);
    auto *body = BraceStmt::create(C, SourceLoc(), ASTNode(returnStmt), SourceLoc());
    funcDecl->setBody(body);
}

With that done, we now need to make this method get called. We've seen previously that deriveCaseIterable_enum_getter is registered by deriveCaseIterable(). If we inspect the contents of that method, we'll find that it checks the name of the requirement being derived:

ValueDecl *DerivedConformance::deriveCaseIterable(ValueDecl *requirement) {
    // Deleted to make stuff shorter: Some pre-checks

    if (requirement->getBaseName() != TC.Context.Id_allCases) {
        // Deleted to make stuff shorter: Throw compilation error
    }

    auto *returnTy = computeAllCasesType(Nominal); // Define the [MyEnum] return type
    // Deleted to make stuff shorter: Define allCases's getter
    declareDerivedProperty(TC.Context.Id_allCases, returnTy, returnTy, /*isStatic=*/true, /*isFinal=*/true);
    // Deleted to make stuff shorter: Prepare allCases's getter
    getterDecl->setBodySynthesizer(&deriveCaseIterable_enum_getter);
}

After searching a bit, I found that the Id_allCases identifier comes from a file named KnownIdentifiers.def, so I edited it to add a new Id_first identifier for our feature. I also added Id_first to the getDerivableRequirement() method mentioned above so the compiler knows that this requirement can be derived.

For this feature to work, we need to keep the old allCases logic but add an else block to handle the new first requirement.

After creating a block for first, we need to change returnTy to be MyEnum instead of [MyEnum] and have declareDerivedProperty() use Id_first as the property name instead of Id_allCases, and finally, have setBodySynthesizer use the new method.

To make returnTy be MyEnum, I looked up how computeAllCasesType() retrieves the enum's type, which turned out to be a call to Nominal->getDeclaredInterfaceType().

After some coding, the final method looks like this: (You can see the full version here.)

ValueDecl *DerivedConformance::deriveCaseIterable(ValueDecl *requirement) {
    // Deleted to make stuff shorter: Some pre-checks

    Type returnTy;
    Identifier propertyId;

    if (requirement->getBaseName() == TC.Context.Id_allCases) {
        returnTy = computeAllCasesType(Nominal);
        propertyId = TC.Context.Id_allCases;
    } else if (requirement->getBaseName() == TC.Context.Id_first) {
        returnTy = Nominal->getDeclaredInterfaceType();
        propertyId = TC.Context.Id_first;
    } else {
        // Deleted to make stuff shorter: Throw compilation error
    }

    // Deleted to make stuff shorter: Define allCases's getter
    declareDerivedProperty(propertyId, returnTy, returnTy, /*isStatic=*/true, /*isFinal=*/true);

    if (requirement->getBaseName() == TC.Context.Id_allCases) {
        getterDecl->setBodySynthesizer(&deriveCaseIterable_enum_getter);
    } else {
        getterDecl->setBodySynthesizer(&deriveCaseIterable_first);
    }
}

After building the compiler, we can get a CaseIterable enum's first case without explicitly defining it!

enum MyEnum: CaseIterable {
    case foo
    case bar
}

print(MyEnum.first) // prints "foo"

Conclusion

Compilers are scary, and Swift's is no different. I'm still trying to figure out how most things work (if you're a compiler expert, I'm looking for tips on great books and resources!), but as I've said before in my posts, knowing the internals of a language can really help you write efficient code. I had a lot of fun inspecting this feature and I hope it was useful to you in some way.

Follow me on my Twitter - @rockbruno_, and let me know of any suggestions and corrections you want to share.

References and Good reads

The Swift Source Code

Update: People were curious about how first behaves if the enum has no cases: it crashes! We can fix that by adding a new rule to derivesProtocolRequirement that returns false if the requirement is first and the enum is empty, which makes the compiler emit a "does not conform to CaseIterable" error in that case.
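The extra rule could be sketched as follows. This is a hypothetical Swift model of the CaseIterable branch of derivesProtocolRequirement (EnumCase, canDerive and the string requirement names are all made up; the real check is C++ and works on AST declarations): the two original conditions apply to every requirement, and first additionally refuses to derive for an enum with no cases.

```swift
// Hypothetical description of one enum case.
struct EnumCase {
    let name: String
    let hasAssociatedValues: Bool
    let isPotentiallyUnavailable: Bool   // e.g. a case marked with @available
}

// The two original CaseIterable rules, plus the proposed one for `first`.
func canDerive(requirement: String, cases: [EnumCase]) -> Bool {
    let hasUnavailableCase = cases.contains { $0.isPotentiallyUnavailable }
    let hasOnlySimpleCases = cases.allSatisfy { !$0.hasAssociatedValues }
    guard !hasUnavailableCase && hasOnlySimpleCases else { return false }

    switch requirement {
    case "allCases":
        return true              // an empty enum still gets an empty allCases
    case "first":
        return !cases.isEmpty    // new rule: nothing to return, so refuse to derive
    default:
        return false
    }
}

let foo = EnumCase(name: "foo", hasAssociatedValues: false, isPotentiallyUnavailable: false)
print(canDerive(requirement: "first", cases: [foo]))  // true
print(canDerive(requirement: "first", cases: []))     // false
```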