Skip to content

Commit 4671dc5

Browse files
committed
[AutoDiff upstream] Add TBDGen for AutoDiff symbols.
Emit symbols for `@differentiable` and `@derivative` declaration attributes: - Differentiability witness symbols. - Derivative function (JVP/VJP) symbols. - Linear map (differential/pullback) symbols. Add TBDGen test.
1 parent aa66cce commit 4671dc5

File tree

3 files changed

+206
-0
lines changed

3 files changed

+206
-0
lines changed

lib/TBDGen/TBDGen.cpp

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -502,6 +502,95 @@ void TBDGenVisitor::addConformances(DeclContext *DC) {
502502
}
503503
}
504504

505+
void TBDGenVisitor::addAutoDiffLinearMapFunction(AbstractFunctionDecl *original,
506+
AutoDiffConfig config,
507+
AutoDiffLinearMapKind kind) {
508+
auto &ctx = original->getASTContext();
509+
auto declRef =
510+
SILDeclRef(original).asForeign(requiresForeignEntryPoint(original));
511+
512+
if (!declRef.isSerialized())
513+
return;
514+
// Linear maps are public only when the original function is serialized.
515+
if (!declRef.isSerialized())
516+
return;
517+
// Differential functions are emitted only when forward-mode is enabled.
518+
if (kind == AutoDiffLinearMapKind::Differential &&
519+
!ctx.LangOpts.EnableExperimentalForwardModeDifferentiation)
520+
return;
521+
auto *loweredParamIndices = autodiff::getLoweredParameterIndices(
522+
config.parameterIndices,
523+
original->getInterfaceType()->castTo<AnyFunctionType>());
524+
Mangle::ASTMangler mangler;
525+
AutoDiffConfig silConfig{loweredParamIndices, config.resultIndices,
526+
config.derivativeGenericSignature};
527+
std::string linearMapName =
528+
mangler.mangleAutoDiffLinearMapHelper(declRef.mangle(), kind, silConfig);
529+
addSymbol(linearMapName);
530+
}
531+
532+
void TBDGenVisitor::addAutoDiffDerivativeFunction(
533+
AbstractFunctionDecl *original, IndexSubset *parameterIndices,
534+
GenericSignature derivativeGenericSignature,
535+
AutoDiffDerivativeFunctionKind kind) {
536+
auto *assocFnId = AutoDiffDerivativeFunctionIdentifier::get(
537+
kind, parameterIndices, derivativeGenericSignature,
538+
original->getASTContext());
539+
auto declRef =
540+
SILDeclRef(original).asForeign(requiresForeignEntryPoint(original));
541+
addSymbol(declRef.asAutoDiffDerivativeFunction(assocFnId));
542+
}
543+
544+
void TBDGenVisitor::addDifferentiabilityWitness(
545+
AbstractFunctionDecl *original, IndexSubset *astParameterIndices,
546+
IndexSubset *resultIndices, GenericSignature derivativeGenericSignature) {
547+
bool foreign = requiresForeignEntryPoint(original);
548+
auto declRef = SILDeclRef(original).asForeign(foreign);
549+
550+
// Skip symbol emission for original functions that do not have public
551+
// linkage. Exclude original functions that require a foreign entry point with
552+
// `public_external` linkage.
553+
auto originalLinkage = declRef.getLinkage(ForDefinition);
554+
if (foreign)
555+
originalLinkage = stripExternalFromLinkage(originalLinkage);
556+
if (originalLinkage != SILLinkage::Public)
557+
return;
558+
559+
auto *silParamIndices = autodiff::getLoweredParameterIndices(
560+
astParameterIndices,
561+
original->getInterfaceType()->castTo<AnyFunctionType>());
562+
563+
auto originalMangledName = declRef.mangle();
564+
AutoDiffConfig config{silParamIndices, resultIndices,
565+
derivativeGenericSignature};
566+
SILDifferentiabilityWitnessKey key(originalMangledName, config);
567+
568+
Mangle::ASTMangler mangler;
569+
auto mangledName = mangler.mangleSILDifferentiabilityWitnessKey(key);
570+
addSymbol(mangledName);
571+
}
572+
573+
void TBDGenVisitor::addDerivativeConfiguration(AbstractFunctionDecl *original,
574+
AutoDiffConfig config) {
575+
auto inserted = AddedDerivatives.insert({original, config});
576+
if (!inserted.second)
577+
return;
578+
579+
addAutoDiffLinearMapFunction(original, config,
580+
AutoDiffLinearMapKind::Differential);
581+
addAutoDiffLinearMapFunction(original, config,
582+
AutoDiffLinearMapKind::Pullback);
583+
addAutoDiffDerivativeFunction(original, config.parameterIndices,
584+
config.derivativeGenericSignature,
585+
AutoDiffDerivativeFunctionKind::JVP);
586+
addAutoDiffDerivativeFunction(original, config.parameterIndices,
587+
config.derivativeGenericSignature,
588+
AutoDiffDerivativeFunctionKind::VJP);
589+
addDifferentiabilityWitness(original, config.parameterIndices,
590+
config.resultIndices,
591+
config.derivativeGenericSignature);
592+
}
593+
505594
/// Determine whether dynamic replacement should be emitted for the allocator or
506595
/// the initializer given a decl.
507596
/// The rule is that structs and convenience init of classes emit a
@@ -565,6 +654,22 @@ void TBDGenVisitor::visitAbstractFunctionDecl(AbstractFunctionDecl *AFD) {
565654
addSymbol(SILDeclRef(AFD).asForeign());
566655
}
567656

657+
// Add derivative function symbols.
658+
for (const auto *differentiableAttr :
659+
AFD->getAttrs().getAttributes<DifferentiableAttr>())
660+
addDerivativeConfiguration(
661+
AFD,
662+
AutoDiffConfig(differentiableAttr->getParameterIndices(),
663+
IndexSubset::get(AFD->getASTContext(), 1, {0}),
664+
differentiableAttr->getDerivativeGenericSignature()));
665+
for (const auto *derivativeAttr :
666+
AFD->getAttrs().getAttributes<DerivativeAttr>())
667+
addDerivativeConfiguration(
668+
derivativeAttr->getOriginalFunction(),
669+
AutoDiffConfig(derivativeAttr->getParameterIndices(),
670+
IndexSubset::get(AFD->getASTContext(), 1, {0}),
671+
AFD->getGenericSignature()));
672+
568673
visitDefaultArguments(AFD, AFD->getParameters());
569674
}
570675

@@ -617,6 +722,15 @@ void TBDGenVisitor::visitAbstractStorageDecl(AbstractStorageDecl *ASD) {
617722
ASD->visitEmittedAccessors([&](AccessorDecl *accessor) {
618723
visitFuncDecl(accessor);
619724
});
725+
726+
// Add derivative function symbols.
727+
for (const auto *differentiableAttr :
728+
ASD->getAttrs().getAttributes<DifferentiableAttr>())
729+
addDerivativeConfiguration(
730+
ASD->getAccessor(AccessorKind::Get),
731+
AutoDiffConfig(differentiableAttr->getParameterIndices(),
732+
IndexSubset::get(ASD->getASTContext(), 1, {0}),
733+
differentiableAttr->getDerivativeGenericSignature()));
620734
}
621735

622736
void TBDGenVisitor::visitVarDecl(VarDecl *VD) {

lib/TBDGen/TBDGenVisitor.h

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,14 @@ class TBDGenVisitor : public ASTVisitor<TBDGenVisitor> {
7070
ModuleDecl *SwiftModule;
7171
const TBDGenOptions &Opts;
7272

73+
/// A set of original function and derivative configuration pairs for which
74+
/// derivative symbols have been emitted.
75+
///
76+
/// Used to deduplicate derivative symbol emission for `@differentiable` and
77+
/// `@derivative` attributes.
78+
llvm::DenseSet<std::pair<AbstractFunctionDecl *, AutoDiffConfig>>
79+
AddedDerivatives;
80+
7381
private:
7482
std::vector<Decl*> DeclStack;
7583
std::unique_ptr<std::map<std::string, InstallNameStore>>
@@ -98,6 +106,34 @@ class TBDGenVisitor : public ASTVisitor<TBDGenVisitor> {
98106
void addAssociatedConformanceDescriptor(AssociatedConformance conformance);
99107
void addBaseConformanceDescriptor(BaseConformance conformance);
100108

109+
/// Adds the symbol for the linear map function of the given kind associated
110+
/// with the given original function and derivative function configuration.
111+
void addAutoDiffLinearMapFunction(AbstractFunctionDecl *original,
112+
AutoDiffConfig config,
113+
AutoDiffLinearMapKind kind);
114+
115+
/// Adds the symbol for the autodiff function of the given kind associated
116+
/// with the given original function, parameter indices, and derivative
117+
/// generic signature.
118+
void
119+
addAutoDiffDerivativeFunction(AbstractFunctionDecl *original,
120+
IndexSubset *parameterIndices,
121+
GenericSignature derivativeGenericSignature,
122+
AutoDiffDerivativeFunctionKind kind);
123+
124+
/// Adds the symbol for the differentiability witness associated with the
125+
/// given original function, AST parameter indices, result indices, and
126+
/// derivative generic signature.
127+
void addDifferentiabilityWitness(AbstractFunctionDecl *original,
128+
IndexSubset *astParameterIndices,
129+
IndexSubset *resultIndices,
130+
GenericSignature derivativeGenericSignature);
131+
132+
/// Adds symbols associated with the given original function and
133+
/// derivative function configuration.
134+
void addDerivativeConfiguration(AbstractFunctionDecl *original,
135+
AutoDiffConfig config);
136+
101137
public:
102138
TBDGenVisitor(llvm::MachO::InterfaceFile &symbols,
103139
llvm::MachO::TargetList targets, StringSet *stringSymbols,
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
// RUN: %target-swift-frontend -emit-ir -o/dev/null -parse-as-library -module-name test -validate-tbd-against-ir=all %s
2+
// RUN: %target-swift-frontend -emit-ir -o/dev/null -parse-as-library -module-name test -validate-tbd-against-ir=all %s -O
3+
// RUN: %target-swift-frontend -emit-ir -o/dev/null -parse-as-library -module-name test -validate-tbd-against-ir=missing %s -enable-testing
4+
// RUN: %target-swift-frontend -emit-ir -o/dev/null -parse-as-library -module-name test -validate-tbd-against-ir=missing %s -enable-testing -O
5+
6+
import _Differentiation
7+
8+
@differentiable
9+
public func topLevelDifferentiable(_ x: Float, _ y: Float) -> Float { x }
10+
11+
public func topLevelHasDerivative<T: Differentiable>(_ x: T) -> T {
12+
x
13+
}
14+
15+
@derivative(of: topLevelHasDerivative)
16+
public func topLevelDerivative<T: Differentiable>(_ x: T) -> (
17+
value: T, pullback: (T.TangentVector) -> T.TangentVector
18+
) {
19+
fatalError()
20+
}
21+
22+
struct Struct: Differentiable {
23+
var stored: Float
24+
25+
// Test property.
26+
@differentiable
27+
public var property: Float {
28+
stored
29+
}
30+
31+
// Test initializer.
32+
@differentiable
33+
public init(_ x: Float) {
34+
stored = x
35+
}
36+
37+
// Test method.
38+
public func method(x: Float, y: Float) -> Float { x }
39+
40+
@derivative(of: method)
41+
public func jvpMethod(x: Float, y: Float) -> (
42+
value: Float, differential: (TangentVector, Float, Float) -> Float
43+
) {
44+
fatalError()
45+
}
46+
47+
// Test subscript.
48+
public subscript(x: Float) -> Float { x }
49+
50+
@derivative(of: subscript)
51+
public func vjpSubscript(x: Float) -> (
52+
value: Float, pullback: (Float) -> (TangentVector, Float)
53+
) {
54+
fatalError()
55+
}
56+
}

0 commit comments

Comments
 (0)