Skip to content

[AutoDiff upstream] Add differentiation transform. #30781

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions include/swift/AST/AutoDiff.h
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,12 @@ struct AutoDiffConfig {
SWIFT_DEBUG_DUMP;
};

inline llvm::raw_ostream &operator<<(llvm::raw_ostream &s,
const SILAutoDiffIndices &indices) {
indices.print(s);
return s;
}

/// A semantic function result type: either a formal function result type or
/// an `inout` parameter type. Used in derivative function type calculation.
struct AutoDiffSemanticFunctionResultType {
Expand Down
30 changes: 30 additions & 0 deletions include/swift/AST/DiagnosticsSIL.def
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,36 @@ ERROR(not_constant_evaluable, none, "not constant evaluable", ())
ERROR(constexpr_imported_func_not_onone, none, "imported constant evaluable "
"function '%0' must be annotated '@_optimize(none)'", (StringRef))

// Differentiation transform diagnostics
ERROR(autodiff_internal_swift_not_imported,none,
"Automatic differentiation internal error: the Swift module is not "
"imported", ())
ERROR(autodiff_differentiation_module_not_imported,none,
"Automatic differentiation requires the '_Differentiation' module to be "
"imported", ())
ERROR(autodiff_conversion_to_linear_function_not_supported,none,
"conversion to '@differentiable(linear)' function type is not yet "
"supported", ())
ERROR(autodiff_function_not_differentiable_error,none,
"function is not differentiable", ())
ERROR(autodiff_expression_not_differentiable_error,none,
"expression is not differentiable", ())
NOTE(autodiff_expression_not_differentiable_note,none,
"expression is not differentiable", ())
NOTE(autodiff_when_differentiating_function_call,none,
"when differentiating this function call", ())
NOTE(autodiff_when_differentiating_function_definition,none,
"when differentiating this function definition", ())
NOTE(autodiff_implicitly_inherited_differentiable_attr_here,none,
"differentiability required by the corresponding protocol requirement "
"here", ())
NOTE(autodiff_jvp_control_flow_not_supported,none,
"forward-mode differentiation does not yet support control flow", ())
NOTE(autodiff_control_flow_not_supported,none,
"cannot differentiate unsupported control flow", ())
NOTE(autodiff_missing_return,none,
"missing return for differentiation", ())

ERROR(non_physical_addressof,none,
"addressof only works with purely physical lvalues; "
"use 'withUnsafePointer' or 'withUnsafeBytes' unless you're implementing "
Expand Down
3 changes: 3 additions & 0 deletions include/swift/Basic/LangOptions.h
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,9 @@ namespace swift {
/// `@differentiable` declaration attribute, etc.
bool EnableExperimentalDifferentiableProgramming = false;

/// Whether to enable forward mode differentiation.
bool EnableExperimentalForwardModeDifferentiation = false;

/// Whether to enable experimental `AdditiveArithmetic` derived
/// conformances.
bool EnableExperimentalAdditiveArithmeticDerivedConformances = false;
Expand Down
8 changes: 8 additions & 0 deletions include/swift/Option/Options.td
Original file line number Diff line number Diff line change
Expand Up @@ -498,6 +498,14 @@ def disable_bridging_pch : Flag<["-"], "disable-bridging-pch">,
HelpText<"Disable automatic generation of bridging PCH files">;

// Experimental feature options

// Note: this flag will be removed when JVP/differential generation in the
// differentiation transform is robust.
def enable_experimental_forward_mode_differentiation :
Flag<["-"], "enable-experimental-forward-mode-differentiation">,
Flags<[FrontendOption]>,
HelpText<"Enable experimental forward mode differentiation">;

def enable_experimental_additive_arithmetic_derivation :
Flag<["-"], "enable-experimental-additive-arithmetic-derivation">,
Flags<[FrontendOption]>,
Expand Down
5 changes: 5 additions & 0 deletions include/swift/SIL/SILDifferentiabilityWitness.h
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,11 @@ class SILDifferentiabilityWitness
bool isSerialized() const { return IsSerialized; }
const DeclAttribute *getAttribute() const { return Attribute; }

/// Returns the `SILAutoDiffIndices` corresponding to this config's indices.
// TODO(TF-893): This is a temporary shim for incremental removal of
// `SILAutoDiffIndices`. Eventually remove this.
SILAutoDiffIndices getSILAutoDiffIndices() const;

/// Verify that the differentiability witness is well-formed.
void verify(const SILModule &module) const;

Expand Down
2 changes: 2 additions & 0 deletions include/swift/SILOptimizer/PassManager/Passes.def
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,8 @@ PASS(CopyForwarding, "copy-forwarding",
"Copy Forwarding to Remove Redundant Copies")
PASS(CopyPropagation, "copy-propagation",
"Copy propagation to Remove Redundant SSA Copies")
PASS(Differentiation, "differentiation",
"Automatic Differentiation")
PASS(EpilogueARCMatcherDumper, "sil-epilogue-arc-dumper",
"Print Epilogue retains of Returned Values and Argument releases")
PASS(EpilogueRetainReleaseMatcherDumper, "sil-epilogue-retain-release-dumper",
Expand Down
Loading