Skip to content

Commit aa66cce

Browse files
committed
[AutoDiff upstream] Add differentiation transform.
The differentiation transform does the following: - Canonicalizes differentiability witnesses by filling in missing derivative function entries. - Canonicalizes `differentiable_function` instructions by filling in missing derivative function operands. - If necessary, performs automatic differentiation: generating derivative functions for original functions. - When encountering non-differentiability code, produces a diagnostic and errors out. Partially resolves TF-1211: add the main canonicalization loop. To incrementally stage changes, derivative functions are currently created with empty bodies that fatal error with a nice message. Derivative emitters will be upstreamed separately.
1 parent 1308fc6 commit aa66cce

File tree

24 files changed

+1623
-22
lines changed

24 files changed

+1623
-22
lines changed

include/swift/AST/AutoDiff.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,12 @@ struct AutoDiffConfig {
227227
SWIFT_DEBUG_DUMP;
228228
};
229229

230+
inline llvm::raw_ostream &operator<<(llvm::raw_ostream &s,
231+
const SILAutoDiffIndices &indices) {
232+
indices.print(s);
233+
return s;
234+
}
235+
230236
/// A semantic function result type: either a formal function result type or
231237
/// an `inout` parameter type. Used in derivative function type calculation.
232238
struct AutoDiffSemanticFunctionResultType {

include/swift/AST/DiagnosticsSIL.def

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -445,6 +445,36 @@ ERROR(not_constant_evaluable, none, "not constant evaluable", ())
445445
ERROR(constexpr_imported_func_not_onone, none, "imported constant evaluable "
446446
"function '%0' must be annotated '@_optimize(none)'", (StringRef))
447447

448+
// Differentiation transform diagnostics
449+
ERROR(autodiff_internal_swift_not_imported,none,
450+
"Automatic differentiation internal error: the Swift module is not "
451+
"imported", ())
452+
ERROR(autodiff_differentiation_module_not_imported,none,
453+
"Automatic differentiation requires the '_Differentiation' module to be "
454+
"imported", ())
455+
ERROR(autodiff_conversion_to_linear_function_not_supported,none,
456+
"conversion to '@differentiable(linear)' function type is not yet "
457+
"supported", ())
458+
ERROR(autodiff_function_not_differentiable_error,none,
459+
"function is not differentiable", ())
460+
ERROR(autodiff_expression_not_differentiable_error,none,
461+
"expression is not differentiable", ())
462+
NOTE(autodiff_expression_not_differentiable_note,none,
463+
"expression is not differentiable", ())
464+
NOTE(autodiff_when_differentiating_function_call,none,
465+
"when differentiating this function call", ())
466+
NOTE(autodiff_when_differentiating_function_definition,none,
467+
"when differentiating this function definition", ())
468+
NOTE(autodiff_implicitly_inherited_differentiable_attr_here,none,
469+
"differentiability required by the corresponding protocol requirement "
470+
"here", ())
471+
NOTE(autodiff_jvp_control_flow_not_supported,none,
472+
"forward-mode differentiation does not yet support control flow", ())
473+
NOTE(autodiff_control_flow_not_supported,none,
474+
"cannot differentiate unsupported control flow", ())
475+
NOTE(autodiff_missing_return,none,
476+
"missing return for differentiation", ())
477+
448478
ERROR(non_physical_addressof,none,
449479
"addressof only works with purely physical lvalues; "
450480
"use 'withUnsafePointer' or 'withUnsafeBytes' unless you're implementing "

include/swift/Basic/LangOptions.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,9 @@ namespace swift {
327327
/// `@differentiable` declaration attribute, etc.
328328
bool EnableExperimentalDifferentiableProgramming = false;
329329

330+
/// Whether to enable forward mode differentiation.
331+
bool EnableExperimentalForwardModeDifferentiation = false;
332+
330333
/// Whether to enable experimental `AdditiveArithmetic` derived
331334
/// conformances.
332335
bool EnableExperimentalAdditiveArithmeticDerivedConformances = false;

include/swift/Option/Options.td

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -498,6 +498,14 @@ def disable_bridging_pch : Flag<["-"], "disable-bridging-pch">,
498498
HelpText<"Disable automatic generation of bridging PCH files">;
499499

500500
// Experimental feature options
501+
502+
// Note: this flag will be removed when JVP/differential generation in the
503+
// differentiation transform is robust.
504+
def enable_experimental_forward_mode_differentiation :
505+
Flag<["-"], "enable-experimental-forward-mode-differentiation">,
506+
Flags<[FrontendOption]>,
507+
HelpText<"Enable experimental forward mode differentiation">;
508+
501509
def enable_experimental_additive_arithmetic_derivation :
502510
Flag<["-"], "enable-experimental-additive-arithmetic-derivation">,
503511
Flags<[FrontendOption]>,

include/swift/SIL/SILDifferentiabilityWitness.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,11 @@ class SILDifferentiabilityWitness
132132
bool isSerialized() const { return IsSerialized; }
133133
const DeclAttribute *getAttribute() const { return Attribute; }
134134

135+
/// Returns the `SILAutoDiffIndices` corresponding to this config's indices.
136+
// TODO(TF-893): This is a temporary shim for incremental removal of
137+
// `SILAutoDiffIndices`. Eventually remove this.
138+
SILAutoDiffIndices getSILAutoDiffIndices() const;
139+
135140
/// Verify that the differentiability witness is well-formed.
136141
void verify(const SILModule &module) const;
137142

include/swift/SILOptimizer/PassManager/Passes.def

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,8 @@ PASS(CopyForwarding, "copy-forwarding",
120120
"Copy Forwarding to Remove Redundant Copies")
121121
PASS(CopyPropagation, "copy-propagation",
122122
"Copy propagation to Remove Redundant SSA Copies")
123+
PASS(Differentiation, "differentiation",
124+
"Automatic Differentiation")
123125
PASS(EpilogueARCMatcherDumper, "sil-epilogue-arc-dumper",
124126
"Print Epilogue retains of Returned Values and Argument releases")
125127
PASS(EpilogueRetainReleaseMatcherDumper, "sil-epilogue-retain-release-dumper",

0 commit comments

Comments
 (0)