@@ -502,6 +502,95 @@ void TBDGenVisitor::addConformances(DeclContext *DC) {
502
502
}
503
503
}
504
504
505
+ void TBDGenVisitor::addAutoDiffLinearMapFunction (AbstractFunctionDecl *original,
506
+ AutoDiffConfig config,
507
+ AutoDiffLinearMapKind kind) {
508
+ auto &ctx = original->getASTContext ();
509
+ auto declRef =
510
+ SILDeclRef (original).asForeign (requiresForeignEntryPoint (original));
511
+
512
+ if (!declRef.isSerialized ())
513
+ return ;
514
+ // Linear maps are public only when the original function is serialized.
515
+ if (!declRef.isSerialized ())
516
+ return ;
517
+ // Differential functions are emitted only when forward-mode is enabled.
518
+ if (kind == AutoDiffLinearMapKind::Differential &&
519
+ !ctx.LangOpts .EnableExperimentalForwardModeDifferentiation )
520
+ return ;
521
+ auto *loweredParamIndices = autodiff::getLoweredParameterIndices (
522
+ config.parameterIndices ,
523
+ original->getInterfaceType ()->castTo <AnyFunctionType>());
524
+ Mangle::ASTMangler mangler;
525
+ AutoDiffConfig silConfig{loweredParamIndices, config.resultIndices ,
526
+ config.derivativeGenericSignature };
527
+ std::string linearMapName =
528
+ mangler.mangleAutoDiffLinearMapHelper (declRef.mangle (), kind, silConfig);
529
+ addSymbol (linearMapName);
530
+ }
531
+
532
+ void TBDGenVisitor::addAutoDiffDerivativeFunction (
533
+ AbstractFunctionDecl *original, IndexSubset *parameterIndices,
534
+ GenericSignature derivativeGenericSignature,
535
+ AutoDiffDerivativeFunctionKind kind) {
536
+ auto *assocFnId = AutoDiffDerivativeFunctionIdentifier::get (
537
+ kind, parameterIndices, derivativeGenericSignature,
538
+ original->getASTContext ());
539
+ auto declRef =
540
+ SILDeclRef (original).asForeign (requiresForeignEntryPoint (original));
541
+ addSymbol (declRef.asAutoDiffDerivativeFunction (assocFnId));
542
+ }
543
+
544
+ void TBDGenVisitor::addDifferentiabilityWitness (
545
+ AbstractFunctionDecl *original, IndexSubset *astParameterIndices,
546
+ IndexSubset *resultIndices, GenericSignature derivativeGenericSignature) {
547
+ bool foreign = requiresForeignEntryPoint (original);
548
+ auto declRef = SILDeclRef (original).asForeign (foreign);
549
+
550
+ // Skip symbol emission for original functions that do not have public
551
+ // linkage. Exclude original functions that require a foreign entry point with
552
+ // `public_external` linkage.
553
+ auto originalLinkage = declRef.getLinkage (ForDefinition);
554
+ if (foreign)
555
+ originalLinkage = stripExternalFromLinkage (originalLinkage);
556
+ if (originalLinkage != SILLinkage::Public)
557
+ return ;
558
+
559
+ auto *silParamIndices = autodiff::getLoweredParameterIndices (
560
+ astParameterIndices,
561
+ original->getInterfaceType ()->castTo <AnyFunctionType>());
562
+
563
+ auto originalMangledName = declRef.mangle ();
564
+ AutoDiffConfig config{silParamIndices, resultIndices,
565
+ derivativeGenericSignature};
566
+ SILDifferentiabilityWitnessKey key (originalMangledName, config);
567
+
568
+ Mangle::ASTMangler mangler;
569
+ auto mangledName = mangler.mangleSILDifferentiabilityWitnessKey (key);
570
+ addSymbol (mangledName);
571
+ }
572
+
573
+ void TBDGenVisitor::addDerivativeConfiguration (AbstractFunctionDecl *original,
574
+ AutoDiffConfig config) {
575
+ auto inserted = AddedDerivatives.insert ({original, config});
576
+ if (!inserted.second )
577
+ return ;
578
+
579
+ addAutoDiffLinearMapFunction (original, config,
580
+ AutoDiffLinearMapKind::Differential);
581
+ addAutoDiffLinearMapFunction (original, config,
582
+ AutoDiffLinearMapKind::Pullback);
583
+ addAutoDiffDerivativeFunction (original, config.parameterIndices ,
584
+ config.derivativeGenericSignature ,
585
+ AutoDiffDerivativeFunctionKind::JVP);
586
+ addAutoDiffDerivativeFunction (original, config.parameterIndices ,
587
+ config.derivativeGenericSignature ,
588
+ AutoDiffDerivativeFunctionKind::VJP);
589
+ addDifferentiabilityWitness (original, config.parameterIndices ,
590
+ config.resultIndices ,
591
+ config.derivativeGenericSignature );
592
+ }
593
+
505
594
// / Determine whether dynamic replacement should be emitted for the allocator or
506
595
// / the initializer given a decl.
507
596
// / The rule is that structs and convenience init of classes emit a
@@ -565,6 +654,22 @@ void TBDGenVisitor::visitAbstractFunctionDecl(AbstractFunctionDecl *AFD) {
565
654
addSymbol (SILDeclRef (AFD).asForeign ());
566
655
}
567
656
657
+ // Add derivative function symbols.
658
+ for (const auto *differentiableAttr :
659
+ AFD->getAttrs ().getAttributes <DifferentiableAttr>())
660
+ addDerivativeConfiguration (
661
+ AFD,
662
+ AutoDiffConfig (differentiableAttr->getParameterIndices (),
663
+ IndexSubset::get (AFD->getASTContext (), 1 , {0 }),
664
+ differentiableAttr->getDerivativeGenericSignature ()));
665
+ for (const auto *derivativeAttr :
666
+ AFD->getAttrs ().getAttributes <DerivativeAttr>())
667
+ addDerivativeConfiguration (
668
+ derivativeAttr->getOriginalFunction (),
669
+ AutoDiffConfig (derivativeAttr->getParameterIndices (),
670
+ IndexSubset::get (AFD->getASTContext (), 1 , {0 }),
671
+ AFD->getGenericSignature ()));
672
+
568
673
visitDefaultArguments (AFD, AFD->getParameters ());
569
674
}
570
675
@@ -617,6 +722,15 @@ void TBDGenVisitor::visitAbstractStorageDecl(AbstractStorageDecl *ASD) {
617
722
ASD->visitEmittedAccessors ([&](AccessorDecl *accessor) {
618
723
visitFuncDecl (accessor);
619
724
});
725
+
726
+ // Add derivative function symbols.
727
+ for (const auto *differentiableAttr :
728
+ ASD->getAttrs ().getAttributes <DifferentiableAttr>())
729
+ addDerivativeConfiguration (
730
+ ASD->getAccessor (AccessorKind::Get),
731
+ AutoDiffConfig (differentiableAttr->getParameterIndices (),
732
+ IndexSubset::get (ASD->getASTContext (), 1 , {0 }),
733
+ differentiableAttr->getDerivativeGenericSignature ()));
620
734
}
621
735
622
736
void TBDGenVisitor::visitVarDecl (VarDecl *VD) {
0 commit comments