diff --git a/dspy/adapters/json_adapter.py b/dspy/adapters/json_adapter.py index ed364b066e..96597f49fb 100644 --- a/dspy/adapters/json_adapter.py +++ b/dspy/adapters/json_adapter.py @@ -5,6 +5,7 @@ import textwrap from copy import deepcopy from typing import Any, Dict, KeysView, Literal, NamedTuple, Type +from typing import get_args, Union import json_repair import litellm @@ -43,25 +44,7 @@ def __call__( inputs = self.format(signature, demos, inputs) inputs = dict(prompt=inputs) if isinstance(inputs, str) else dict(messages=inputs) - try: - provider = lm.model.split("/", 1)[0] or "openai" - params = litellm.get_supported_openai_params(model=lm.model, custom_llm_provider=provider) - if params and "response_format" in params: - try: - response_format = _get_structured_outputs_response_format(signature) - outputs = lm(**inputs, **lm_kwargs, response_format=response_format) - except Exception as e: - logger.debug( - f"Failed to obtain response using signature-based structured outputs" - f" response format: Falling back to default 'json_object' response format." - f" Exception: {e}" - ) - outputs = lm(**inputs, **lm_kwargs, response_format={"type": "json_object"}) - else: - outputs = lm(**inputs, **lm_kwargs) - - except litellm.UnsupportedParamsError: - outputs = lm(**inputs, **lm_kwargs) + outputs = lm(**inputs, **lm_kwargs) values = [] @@ -71,7 +54,7 @@ def __call__( signature.output_fields.keys() ), f"Expected {signature.output_fields.keys()} but got {value.keys()}" values.append(value) - + return values def format( diff --git a/dspy/adapters/utils.py b/dspy/adapters/utils.py index ad07a5e631..fb848ee475 100644 --- a/dspy/adapters/utils.py +++ b/dspy/adapters/utils.py @@ -87,16 +87,42 @@ def find_enum_member(enum, identifier): def parse_value(value, annotation): + origin_annotation = annotation + + # Handle Optional[T] (i.e., Union[T, None]) and validate Union assumptions + if get_origin(annotation) is Union: + args = get_args(annotation) + non_none_args = [arg for arg in args if arg is not type(None)] + + if len(non_none_args) == 1: + annotation = non_none_args[0] + else: + raise TypeError( + f"Unsupported Union type: {annotation}. " + f"Expected Optional[T] (i.e., Union[T, None]), but got Union with multiple concrete types: {non_none_args}" + ) + + # Explicitly return None if the value is None and the annotation allowed it + if value is None: + if get_origin(origin_annotation) is Union and type(None) in get_args(origin_annotation): + return None + else: + raise TypeError(f"Received None for non-optional annotation: {annotation}") + + # Handle str if annotation is str: return str(value) + # Handle Enums if isinstance(annotation, enum.EnumMeta): return find_enum_member(annotation, value) + # Validate if input is already the right type if not isinstance(value, str): return TypeAdapter(annotation).validate_python(value) - candidate = json_repair.loads(value) # json_repair.loads returns "" on failure. 
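+    # Note: json_repair.loads returns "" on failure, which the empty-string check below relies on.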
+ # Try to parse string value + candidate = json_repair.loads(value) if candidate == "" and value != "": try: candidate = ast.literal_eval(value) diff --git a/dspy/teleprompt/__init__.py b/dspy/teleprompt/__init__.py index 3168cd1c44..2fefec2f21 100644 --- a/dspy/teleprompt/__init__.py +++ b/dspy/teleprompt/__init__.py @@ -6,6 +6,7 @@ from dspy.teleprompt.ensemble import Ensemble from dspy.teleprompt.knn_fewshot import KNNFewShot from dspy.teleprompt.simba import SIMBA +from dspy.teleprompt.simba_fast import SIMBAFast from dspy.teleprompt.mipro_optimizer_v2 import MIPROv2 from dspy.teleprompt.random_search import BootstrapFewShotWithRandomSearch @@ -29,4 +30,5 @@ "LabeledFewShot", "InferRules", "SIMBA", + "SIMBAFast", ] diff --git a/dspy/teleprompt/mipro_optimizer_v2.py b/dspy/teleprompt/mipro_optimizer_v2.py index 0f92ece379..711b227f50 100644 --- a/dspy/teleprompt/mipro_optimizer_v2.py +++ b/dspy/teleprompt/mipro_optimizer_v2.py @@ -3,11 +3,13 @@ import textwrap from collections import defaultdict from typing import Any, Callable, Dict, List, Literal, Optional, Tuple +import select +import sys +import time import numpy as np import optuna from optuna.distributions import CategoricalDistribution - import dspy from dspy.evaluate.evaluate import Evaluate from dspy.propose import GroundedProposer @@ -31,9 +33,9 @@ MIN_MINIBATCH_SIZE = 50 AUTO_RUN_SETTINGS = { - "light": {"num_trials": 7, "val_size": 100}, - "medium": {"num_trials": 25, "val_size": 300}, - "heavy": {"num_trials": 50, "val_size": 1000}, + "light": {"n": 6, "val_size": 100}, + "medium": {"n": 12, "val_size": 300}, + "heavy": {"n": 18, "val_size": 1000}, } # ANSI escape codes for colors @@ -53,9 +55,9 @@ def __init__( teacher_settings: Dict = {}, max_bootstrapped_demos: int = 4, max_labeled_demos: int = 4, - auto: Optional[Literal["light", "medium", "heavy"]] = "medium", - num_candidates: int = 10, - num_threads: int = 6, + auto: Optional[Literal["light", "medium", "heavy"]] = "light", + num_candidates: Optional[int] = None, + num_threads: Optional[int] = None, max_errors: int = 10, seed: int = 9, init_temperature: float = 0.5, @@ -69,7 +71,8 @@ def __init__( if auto not in allowed_modes: raise ValueError(f"Invalid value for auto: {auto}. Must be one of {allowed_modes}.") self.auto = auto - + self.num_fewshot_candidates = num_candidates + self.num_instruct_candidates = num_candidates self.num_candidates = num_candidates self.metric = metric self.init_temperature = init_temperature @@ -96,7 +99,7 @@ def compile( trainset: List, teacher: Any = None, valset: Optional[List] = None, - num_trials: int = 30, + num_trials: Optional[int] = None, max_bootstrapped_demos: Optional[int] = None, max_labeled_demos: Optional[int] = None, seed: Optional[int] = None, @@ -109,8 +112,23 @@ def compile( tip_aware_proposer: bool = True, fewshot_aware_proposer: bool = True, requires_permission_to_run: bool = True, - provide_traceback: bool = False, + provide_traceback: Optional[bool] = None, ) -> Any: + + zeroshot_opt = (self.max_bootstrapped_demos == 0) and (self.max_labeled_demos == 0) + + # If auto is None, and num_trials is not provided (but num_candidates is), raise an error that suggests a good num_trials value + if self.auto is None and (self.num_candidates is not None and num_trials is None): + raise ValueError(f"If auto is None, num_trials must also be provided. 
Given num_candidates={self.num_candidates}, we'd recommend setting num_trials to ~{self._set_num_trials_from_num_candidates(student, zeroshot_opt, self.num_candidates)}.") + + # If auto is None, and num_candidates or num_trials is None, raise an error + if self.auto is None and (self.num_candidates is None or num_trials is None): + raise ValueError("If auto is None, num_candidates and num_trials must also be provided.") + + # If auto is provided, and either num_candidates or num_trials is not None, raise an error + if self.auto is not None and (self.num_candidates is not None or num_trials is not None): + raise ValueError("If auto is not None, num_candidates and num_trials cannot be set, since they would be overridden by the auto settings. Please either set auto to None, or do not specify num_candidates and num_trials.") + # Set random seeds seed = seed or self.seed self._set_random_seeds(seed) @@ -125,7 +143,6 @@ def compile( trainset, valset = self._set_and_validate_datasets(trainset, valset) # Set hyperparameters based on run mode (if set) - zeroshot_opt = (self.max_bootstrapped_demos == 0) and (self.max_labeled_demos == 0) num_trials, valset, minibatch = self._set_hyperparams_from_run_mode( student, num_trials, minibatch, zeroshot_opt, valset ) @@ -201,6 +218,15 @@ def _set_random_seeds(self, seed): self.rng = random.Random(seed) np.random.seed(seed) + def _set_num_trials_from_num_candidates(self, program, zeroshot_opt, num_candidates): + num_vars = len(program.predictors()) + if not zeroshot_opt: + num_vars *= 2 # Account for few-shot examples + instruction variables + # num_trials = max(2 * num_vars * log2(num_candidates), 1.5 * num_candidates) + num_trials = int(max(2 * num_vars * np.log2(num_candidates), 1.5 * num_candidates)) + + return num_trials + def _set_hyperparams_from_run_mode( self, program: Any, @@ -212,15 +238,18 @@ def _set_hyperparams_from_run_mode( if self.auto is None: return num_trials, valset, minibatch - num_vars = len(program.predictors()) - if not zeroshot_opt: - num_vars *= 2 # Account for few-shot examples + instruction variables - auto_settings = AUTO_RUN_SETTINGS[self.auto] - num_trials = auto_settings["num_trials"] + valset = create_minibatch(valset, batch_size=auto_settings["val_size"], rng=self.rng) minibatch = len(valset) > MIN_MINIBATCH_SIZE - self.num_candidates = int(np.round(np.min([num_trials * num_vars, (1.5 * num_trials) / num_vars]))) + + # Set num instruct candidates to 1/2 of N if optimizing with few-shot examples, otherwise set to N + # This is because we've found that it's generally better to spend optimization budget on few-shot examples when they are allowed.
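+        # For example, auto="light" (n=6) yields 6 few-shot candidate sets and 3 instruction candidates when few-shot demos are enabled.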
+ self.num_instruct_candidates = auto_settings["n"] if zeroshot_opt else int(auto_settings["n"] * 0.5) + self.num_fewshot_candidates = auto_settings["n"] + + num_trials = self._set_num_trials_from_num_candidates(program, zeroshot_opt, auto_settings["n"]) return num_trials, valset, minibatch @@ -246,7 +275,8 @@ def _print_auto_run_settings(self, num_trials: int, minibatch: bool, valset: Lis f"\nRUNNING WITH THE FOLLOWING {self.auto.upper()} AUTO RUN SETTINGS:" f"\nnum_trials: {num_trials}" f"\nminibatch: {minibatch}" - f"\nnum_candidates: {self.num_candidates}" + f"\nnum_fewshot_candidates: {self.num_fewshot_candidates}" + f"\nnum_instruct_candidates: {self.num_instruct_candidates}" f"\nvalset size: {len(valset)}\n" ) @@ -265,12 +295,12 @@ def _estimate_lm_calls( # Estimate prompt model calls estimated_prompt_model_calls = ( 10 # Data summarizer calls - + self.num_candidates * num_predictors # Candidate generation + + self.num_instruct_candidates * num_predictors # Candidate generation + (num_predictors + 1 if program_aware_proposer else 0) # Program-aware proposer ) prompt_model_line = ( f"{YELLOW}- Prompt Generation: {BLUE}{BOLD}10{ENDC}{YELLOW} data summarizer calls + " - f"{BLUE}{BOLD}{self.num_candidates}{ENDC}{YELLOW} * " + f"{BLUE}{BOLD}{self.num_instruct_candidates}{ENDC}{YELLOW} * " f"{BLUE}{BOLD}{num_predictors}{ENDC}{YELLOW} lm calls in program " f"+ ({BLUE}{BOLD}{num_predictors + 1}{ENDC}{YELLOW}) lm calls in program-aware proposer " f"= {BLUE}{BOLD}{estimated_prompt_model_calls}{ENDC}{YELLOW} prompt model calls{ENDC}" @@ -342,6 +372,7 @@ def _get_user_confirmation( user_confirmation_message = textwrap.dedent( f"""\ To proceed with the execution of this program, please confirm by typing {BLUE}'y'{ENDC} for yes or {BLUE}'n'{ENDC} for no. + If no input is received within 20 seconds, the program will proceed automatically. If you would like to bypass this confirmation step in future executions, set the {YELLOW}`requires_permission_to_run`{ENDC} flag to {YELLOW}`False`{ENDC} when calling compile. @@ -349,10 +380,18 @@ def _get_user_confirmation( """ ) - user_input = ( - input(f"{user_message}\n{user_confirmation_message}\nDo you wish to continue? (y/n): ").strip().lower() - ) - return user_input == "y" + print(f"{user_message}\n{user_confirmation_message}\nDo you wish to continue? (y/n): ", end='', flush=True) + + # Wait for input with timeout + start_time = time.time() + while time.time() - start_time < 20: + if select.select([sys.stdin], [], [], 0.1)[0]: + user_input = sys.stdin.readline().strip().lower() + return user_input == "y" + time.sleep(0.1) + + print("\nNo input received within 20 seconds. 
Proceeding with execution...") + return True def _bootstrap_fewshot_examples(self, program: Any, trainset: List, seed: int, teacher: Any) -> Optional[List]: logger.info("\n==> STEP 1: BOOTSTRAP FEWSHOT EXAMPLES <==") @@ -363,14 +402,14 @@ def _bootstrap_fewshot_examples(self, program: Any, trainset: List, seed: int, t else: logger.info("These will be used for informing instruction proposal.\n") - logger.info(f"Bootstrapping N={self.num_candidates} sets of demonstrations...") + logger.info(f"Bootstrapping N={self.num_fewshot_candidates} sets of demonstrations...") zeroshot = self.max_bootstrapped_demos == 0 and self.max_labeled_demos == 0 try: demo_candidates = create_n_fewshot_demo_sets( student=program, - num_candidate_sets=self.num_candidates, + num_candidate_sets=self.num_fewshot_candidates, trainset=trainset, max_labeled_demos=(LABELED_FEWSHOT_EXAMPLES_IN_CONTEXT if zeroshot else self.max_labeled_demos), max_bootstrapped_demos=( @@ -424,12 +463,12 @@ def _propose_instructions( rng=self.rng, ) - logger.info("\nProposing instructions...\n") + logger.info(f"\nProposing N={self.num_instruct_candidates} instructions...\n") instruction_candidates = proposer.propose_instructions_for_program( trainset=trainset, program=program, demo_candidates=demo_candidates, - N=self.num_candidates, + N=self.num_instruct_candidates, T=self.init_temperature, trial_logs={}, ) @@ -465,7 +504,7 @@ def _optimize_prompt_parameters( # Compute the adjusted total trials that we will run (including full evals) run_additional_full_eval_at_end = 1 if num_trials % minibatch_full_eval_steps != 0 else 0 - adjusted_num_trials = (num_trials + num_trials // minibatch_full_eval_steps + 1 + run_additional_full_eval_at_end) if minibatch else num_trials + adjusted_num_trials = int((num_trials + num_trials // minibatch_full_eval_steps + 1 + run_additional_full_eval_at_end) if minibatch else num_trials) logger.info(f"== Trial {1} / {adjusted_num_trials} - Full Evaluation of Default Program ==") default_score, _ = eval_candidate_program( diff --git a/dspy/teleprompt/simba.py b/dspy/teleprompt/simba.py index 8554d0b1e4..1452558667 100644 --- a/dspy/teleprompt/simba.py +++ b/dspy/teleprompt/simba.py @@ -6,6 +6,7 @@ from typing import Callable from dspy.teleprompt.teleprompt import Teleprompter from dspy.teleprompt.simba_utils import prepare_models_for_resampling, wrap_program, append_a_demo, append_a_rule +from dspy.teleprompt.utils import log_token_usage logger = logging.getLogger(__name__) @@ -117,6 +118,14 @@ def register_new_program(prog: dspy.Module, score_list: list[float]): rng.shuffle(data_indices) instance_idx = 0 + M = self.max_steps - 1 + N = self.num_candidates + 1 + program_idxs = [0] * N if M < 1 else [round(i * M / (N - 1)) for i in range(N)] + program_idxs = list(dict.fromkeys(program_idxs)) + + final_candidate_programs = [] + final_candidate_scores = [] + # Parallel runner run_parallel = dspy.Parallel(access_examples=False, num_threads=self.num_threads) @@ -156,12 +165,17 @@ def register_new_program(prog: dspy.Module, score_list: list[float]): # Use the special wrap that includes the 'example' in the output wrapped_candidate_system = wrap_program(candidate_system, self.metric) exec_pairs.append((wrapped_candidate_system, example)) + + # TODO: check to see if token count from all these models is accounted for in lm.history + # Could add their history to main dspy.settings.lm.history # STEP 2: Execute logger.info(f"Sampling program trajectories on {self.bsize} examples x {self.num_candidates} samples.") outputs = 
run_parallel(exec_pairs) assert len(outputs) == len(exec_pairs) == self.bsize * self.num_candidates + dspy.settings.lm.history.extend([entry for model in models for entry in model.history]) + # STEP 3: Sort the training buckets by (max-to-min gap, max score, and max-to-avg gap). buckets = [] largest_max_to_avg_gap = float("-inf") @@ -285,40 +299,30 @@ def register_new_program(prog: dspy.Module, score_list: list[float]): sys_scores = [outputs[i]["score"] for i in range(start, end)] register_new_program(cand_sys, sys_scores) - M = len(winning_programs) - 1 - N = self.num_candidates + 1 - if M < 1: - # Only one or zero winning programs - program_idxs = [0] * N - else: - program_idxs = [round(i * M / (N - 1)) for i in range(N)] - program_idxs = list(dict.fromkeys(program_idxs)) - - candidate_programs = [winning_programs[i].deepcopy() for i in program_idxs] - logger.info(f"VALIDATION: Evaluating {len(candidate_programs)} programs on the full trainset.") - exec_pairs = [(wrap_program(sys, self.metric), ex) for sys in candidate_programs for ex in trainset] - outputs = run_parallel(exec_pairs) - - scores = [] - for idx_prog, prog in enumerate(candidate_programs): - start = idx_prog * len(trainset) - end = (idx_prog + 1) * len(trainset) - sys_scores = [outputs[i]["score"] for i in range(start, end)] - avg_score = sum(sys_scores) / len(sys_scores) if sys_scores else 0.0 - scores.append(avg_score) - if idx_prog != 0: - trial_logs[idx_prog-1]["train_score"] = avg_score - - best_idx = scores.index(max(scores)) if scores else 0 - best_program = candidate_programs[best_idx] + # STEP 8: If it's time for a full evaluation, evaluate the winning program on the full trainset + if batch_idx in program_idxs: + logger.info(f"Batch {batch_idx+1}: Evaluating winning program on full trainset.") + exec_pairs = [(wrap_program(best_program, self.metric), ex) for ex in trainset] + full_outputs = run_parallel(exec_pairs) + scores = [o["score"] for o in full_outputs] + avg_score = sum(scores) / len(scores) + trial_logs[batch_idx]["train_score"] = avg_score + + final_candidate_programs.append(best_program.deepcopy()) + final_candidate_scores.append(avg_score) + + log_token_usage(trial_logs, batch_idx, {"lm": dspy.settings.lm}) + + best_idx = np.argmax(final_candidate_scores) if final_candidate_scores else 0 + # best_idx = scores.index(max(final_candidate_scores)) if final_candidate_scores else 0 + best_program = final_candidate_programs[best_idx] logger.info( - f"Final trainset scores: {scores}, Best: {max(scores) if scores else 'N/A'} " - f"(at index {best_idx if scores else 'N/A'})\n\n\n" + f"Final trainset scores: {final_candidate_scores}, Best: {max(final_candidate_scores) if final_candidate_scores else 'N/A'} " + f"(at index {best_idx if final_candidate_scores else 'N/A'})\n\n\n" ) - # FIXME: Attach all program candidates in decreasing average score to the best program. 
- best_program.candidate_programs = candidate_programs + best_program.candidate_programs = final_candidate_programs best_program.winning_programs = winning_programs best_program.trial_logs = trial_logs - return best_program + return best_program \ No newline at end of file diff --git a/dspy/teleprompt/simba_fast.py b/dspy/teleprompt/simba_fast.py new file mode 100644 index 0000000000..425411da3e --- /dev/null +++ b/dspy/teleprompt/simba_fast.py @@ -0,0 +1,396 @@ +import dspy +import random +import logging + +import numpy as np +from typing import Callable, Optional, Any, Dict +from dspy.teleprompt.teleprompt import Teleprompter +from dspy.teleprompt.simba_utils import prepare_models_for_resampling, wrap_program, append_a_demo, append_a_rule +from dspy.teleprompt.utils import log_token_usage + +logger = logging.getLogger(__name__) + + +# Stochastic Introspective Mini-Batch Ascent +class SIMBAFast(Teleprompter): + def __init__( + self, + *, + metric: Callable, + bsize=32, + num_candidates=6, + max_steps=8, + max_demos=4, + prompt_model: Optional[Any] = None, + teacher_settings: Optional[Dict] = None, + demo_input_field_maxlen=100_000, + num_threads=16, + temperature_for_sampling=0.2, + temperature_for_candidates=0.2, + ): + """ + :param metric: A function (Example, prediction_dict) -> float + :param bsize: mini-batch size + :param num_candidates: how many new candidate programs to produce per iteration + :param max_steps: how many optimization steps to run + :param max_demos: how many demos we allow a predictor to hold before we must drop some + :param demo_input_field_maxlen: how many characters of an input field to keep when building a new demo + :param num_threads: how many threads for run_parallel + :param temperature_for_sampling: temperature used for picking programs for the trajectory-sampling step + :param temperature_for_candidates: temperature used for picking the source program for building new candidates + """ + self.metric = metric + self.bsize = bsize + self.num_candidates = num_candidates + self.max_steps = max_steps + self.max_demos = max_demos + self.prompt_model = prompt_model if prompt_model else dspy.settings.lm + self.teacher_settings = teacher_settings if teacher_settings else {} + self.demo_input_field_maxlen = demo_input_field_maxlen + self.num_threads = num_threads + + self.temperature_for_sampling = temperature_for_sampling + self.temperature_for_candidates = temperature_for_candidates + + if self.max_demos > 0: + self.strategies = [append_a_demo(demo_input_field_maxlen), append_a_rule] + else: + self.strategies = [append_a_rule] + + def compile(self, student: dspy.Module, *, trainset: list[dspy.Example], seed: int = 0): + # Basic checks + assert len(trainset) >= self.bsize, f"Trainset too small: {len(trainset)} < {self.bsize}" + + # Initialize RNG + rng = random.Random(seed) + rng_np = np.random.default_rng(seed) + + programs = [] + program_scores = {} + program_batch_idx = {} + next_program_idx = 0 + batch_idx_to_baseline_scores = {} + + # Helper functions + def calc_average_score(prog_idx: int) -> float: + scores = program_scores.get(prog_idx, []) + if not scores: + return 0.0 + return sum(scores) / len(scores) + + def calc_average_adjusted_score(prog_idx: int) -> float: + prog_scores = program_scores.get(prog_idx, []) + baseline_scores = batch_idx_to_baseline_scores.get(program_batch_idx[prog_idx], []) + + # If either list is empty or not the same length, return 0 or handle how you prefer + if not prog_scores or not baseline_scores: + return 0.0 + if 
len(prog_scores) != len(baseline_scores): + # You can decide how you want to handle mismatch + return 0.0 + + # Elementwise subtraction + adjusted_scores = [p - b for p, b in zip(prog_scores, baseline_scores)] + return sum(adjusted_scores) / len(adjusted_scores) + + def adjusted_top_k_plus_baseline(k: int) -> list[int]: + # Sort all programs by descending average score + scored_programs = sorted(programs, key=lambda p: calc_average_adjusted_score(p.simba_idx), reverse=True) + top_k = [p.simba_idx for p in scored_programs[:k]] + # Ensure baseline=0 is in there: + if 0 not in top_k and len(top_k) > 0: + top_k[-1] = 0 + return list(dict.fromkeys(top_k)) + + def top_k_plus_baseline(k: int) -> list[int]: + # Sort all programs by descending average score + scored_programs = sorted(programs, key=lambda p: calc_average_score(p.simba_idx), reverse=True) + top_k = [p.simba_idx for p in scored_programs[:k]] + # Ensure baseline=0 is in there: + if 0 not in top_k and len(top_k) > 0: + top_k[-1] = 0 + return list(dict.fromkeys(top_k)) + + def softmax_sample(rng_obj: random.Random, program_idxs: list[int], temperature: float) -> int: + if not program_idxs: + raise ValueError("No programs available for softmax sampling.") + + # Unnormalized weights + scores = [calc_average_score(idx) for idx in program_idxs] + exps = [np.exp(s / temperature) for s in scores] + sum_exps = sum(exps) + if sum_exps <= 0: + # Fallback: uniform if all exps are zero + return rng_obj.choice(program_idxs) + + # Weighted random choice + probs = [val / sum_exps for val in exps] + return rng_obj.choices(program_idxs, weights=probs, k=1)[0] + + def register_new_program(prog: dspy.Module, score_list: list[float], batch_idx: int): + nonlocal next_program_idx + next_program_idx += 1 + new_idx = next_program_idx + prog.simba_idx = new_idx + programs.append(prog) + program_scores[new_idx] = score_list + program_batch_idx[new_idx] = batch_idx + + # Initialize the baseline program: index=0 + student = student.deepcopy() + student.simba_idx = 0 + programs.append(student) + program_scores[0] = [] + program_batch_idx[0] = 0 + + winning_programs = [(0,student)] + + # Data shuffling + data_indices = list(range(len(trainset))) + rng.shuffle(data_indices) + instance_idx = 0 + + # Parallel runner + logger.info(f"Creating parallel runner with num_threads: {self.num_threads}") + run_parallel = dspy.Parallel(access_examples=False, num_threads=self.num_threads) + + trial_logs = {} + + # Initialize for hybrid execution reuse + last_batch_outputs = None + + predictor2name = {} + + M = self.max_steps - 1 + N = self.num_candidates + 1 + program_idxs = [0] * N if M < 1 else [round(i * M / (N - 1)) for i in range(N)] + program_idxs = list(dict.fromkeys(program_idxs)) + + # Compute baseline student score on the full trainset + logger.info(f"Evaluating student program on full trainset.") + exec_pairs = [(wrap_program(student, self.metric), ex) for ex in trainset] + full_outputs = run_parallel(exec_pairs) + baseline_scores = [o["score"] for o in full_outputs] + + # Compute average score for the baseline program + avg_baseline_score = sum(baseline_scores) / len(baseline_scores) + logger.info(f"Baseline program (index 0) score: {avg_baseline_score}\n") + + final_candidate_programs = [student] + final_candidate_scores = [avg_baseline_score] + validated_program_outputs = {} # {prog_idx: {example_idx: output_dict}} + + for batch_idx in range(self.max_steps): + trial_logs[batch_idx+1] = {} + + logger.info(f"Starting batch {batch_idx+1} of {self.max_steps}.") + + # 
STEP 1: Get next batch + if instance_idx + self.bsize > len(trainset): + rng.shuffle(data_indices) + instance_idx = 0 + + batch_indices = data_indices[instance_idx : instance_idx + self.bsize] + batch = [trainset[i] for i in batch_indices] + instance_idx += self.bsize + + # Compute student baseline on batch + batch_idx_to_baseline_scores[batch_idx] = [score for i, score in enumerate(baseline_scores) if i in batch_indices] + + # STEP 2 (or hybrid): Collect execution results for bucket building + models = prepare_models_for_resampling(programs[0], self.num_candidates, self.teacher_settings) + top_programs = top_k_plus_baseline(self.num_candidates) + + exec_pairs = [] + + if batch_idx == 0: + # First round — use full trajectory sampling + for model in models: + for example in batch: + chosen_prog_idx = softmax_sample(rng, top_programs, self.temperature_for_sampling) + candidate_system = programs[chosen_prog_idx].deepcopy() + candidate_system.set_lm(model) + + for name, predictor in candidate_system.named_predictors(): + predictor2name[id(predictor)] = name + + wrapped_candidate_system = wrap_program(candidate_system, self.metric) + exec_pairs.append((wrapped_candidate_system, example)) + + logger.info(f"Sampling program trajectories on {self.bsize} examples x {self.num_candidates} samples.") + outputs = run_parallel(exec_pairs) + else: + outputs = last_batch_outputs.copy() if last_batch_outputs else [] + for prog_idx, prog_cache in validated_program_outputs.items(): + for i in batch_indices: + if i in prog_cache: + outputs.append(prog_cache[i]) + + dspy.settings.lm.history.extend([entry for model in models for entry in model.history]) + + # STEP 3: Sort the training buckets by (max-to-min gap, max score, and max-to-avg gap). + buckets = [] + largest_max_to_avg_gap = float("-inf") + batch_10th_percentile_score = np.percentile([float(o["score"]) for o in outputs], 10) + batch_90th_percentile_score = np.percentile([float(o["score"]) for o in outputs], 90) + + # Group `outputs` by example position: results for example idx sit at indices idx, idx + bsize, idx + 2*bsize, ... + for idx, example in enumerate(batch): + # gather all results for this example + bucket = [outputs[i] for i in range(idx, len(outputs), self.bsize)] + bucket.sort(key=lambda x: x["score"], reverse=True) + + max_score = float(bucket[0]["score"]) + min_score = float(bucket[-1]["score"]) + avg_score = sum(x["score"] for x in bucket) / len(bucket) + max_to_min_gap = max_score - min_score + max_to_avg_gap = max_score - avg_score + if max_to_avg_gap > largest_max_to_avg_gap: + largest_max_to_avg_gap = max_to_avg_gap + + buckets.append((bucket, (max_to_min_gap, max_score, max_to_avg_gap))) + + # sort the buckets + buckets.sort(key=lambda x: x[1], reverse=True) + # TODO: if all buckets have a max_to_min gap of 0 and max score <1.0, then we should do more trajectory sampling + + # Baseline for the batch is just the average of all runs + all_scores_in_this_batch = [o["score"] for o in outputs] + baseline_score = sum(all_scores_in_this_batch) / len(all_scores_in_this_batch) + logger.info(f"Batch {batch_idx+1}: Baseline mini-batch score: {baseline_score}\n") + + # summarize_batch([bucket[0] for bucket in buckets]) + # STEP 4: Build new candidate programs by applying a strategy to some top buckets.
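+            # Each candidate copies a softmax-sampled source program, drops a few demos at random, and applies one strategy (append_a_demo or append_a_rule).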
+ system_candidates = [] + for bucket_idx, (bucket, bucket_stats) in enumerate(buckets): + max_to_min_gap, max_score, max_to_avg_gap = bucket_stats + logger.info( + f"Batch {batch_idx+1}: Processing bucket #{bucket_idx+1}, with max score {max_score}, " + f"max-to-min gap {max_to_min_gap}, and max-to-avg gap {max_to_avg_gap}." + ) + + # pick source program + src_prog_idx = softmax_sample( + rng, top_k_plus_baseline(self.num_candidates), self.temperature_for_candidates + ) + system_candidate = programs[src_prog_idx].deepcopy() + + # Drop some demos from each predictor + name2predictor = {} + num_demos_list = [] + + max_demos_tmp = self.max_demos if self.max_demos > 0 else 3 + + for name, predictor in system_candidate.named_predictors(): + name2predictor[name] = predictor + num_demos_list.append(len(predictor.demos)) + + num_demos = max(num_demos_list) if num_demos_list else 0 + num_demos_to_drop = max(rng_np.poisson(num_demos / max_demos_tmp), int(num_demos >= max_demos_tmp)) + num_demos_to_drop = min(num_demos_to_drop, num_demos) + demos_to_drop = [rng.randrange(num_demos) for _ in range(num_demos_to_drop)] + + for name, predictor in name2predictor.items(): + predictor.demos = [demo for idxd, demo in enumerate(predictor.demos) if idxd not in demos_to_drop] + + # Pick a strategy + strategy = rng.choice(self.strategies) + logger.info( + f"Batch {batch_idx+1}: Invoking strategy: {strategy.__name__}" + + (f", having dropped {num_demos_to_drop} demos per predictor" if num_demos_to_drop else "") + ) + + for name, predictor in system_candidate.named_predictors(): + predictor2name[id(predictor)] = name + + try: + strategy( + bucket, + system_candidate, + predictor2name=predictor2name, + name2predictor=name2predictor, + batch_10p_score=batch_10th_percentile_score, + batch_90p_score=batch_90th_percentile_score, + prompt_model=self.prompt_model, + ) + except Exception as e: + logger.error(f"Strategy failed with error: {e}") + continue + + system_candidates.append(system_candidate) + logger.info("\n") + + if len(system_candidates) >= self.num_candidates: + break + + # STEP 5: Evaluate these new system_candidates on the same mini-batch + logger.info(f"Batch {batch_idx+1}: Evaluating {len(system_candidates)} programs on {self.bsize} examples.") + + exec_pairs = [(wrap_program(sys, self.metric), ex) for sys in system_candidates for ex in batch] + outputs = run_parallel(exec_pairs) + assert len(outputs) == len(exec_pairs) == len(system_candidates) * self.bsize + + # STEP 6: Compute average mini-batch scores for each new candidate + candidate_scores = [] + for idx_cand, cand_sys in enumerate(system_candidates): + start = idx_cand * self.bsize + end = (idx_cand + 1) * self.bsize + sys_scores = [outputs[i]["score"] for i in range(start, end)] + avg_sys_score = sum(sys_scores) / len(sys_scores) + candidate_scores.append(avg_sys_score) + + logger.info( + f"Scores after {batch_idx+1} batches: {candidate_scores}, " + f"Best: {max(candidate_scores) if candidate_scores else 'N/A'}\n" + ) + + trial_logs[batch_idx+1]["batch_scores"] = candidate_scores + + # STEP 7: Select the best among these new ones for "winning" record + if candidate_scores: + best_idx_among_candidates = candidate_scores.index(max(candidate_scores)) + best_program = system_candidates[best_idx_among_candidates] + winning_programs.append((batch_idx+1, best_program.deepcopy())) + + # STEP 8: If it's time for a full evaluation, evaluate the winning program on the full trainset + if batch_idx in program_idxs: + logger.info(f"Batch {batch_idx+1}: 
Evaluating winning program on full trainset.") + exec_pairs = [(wrap_program(best_program, self.metric), ex) for ex in trainset] + full_outputs = run_parallel(exec_pairs) + scores = [o["score"] for o in full_outputs] + avg_score = sum(scores) / len(scores) + logger.info(f"Batch {batch_idx+1}: Full trainset score: {avg_score}") + trial_logs[batch_idx + 1]["train_score"] = avg_score + + final_candidate_programs.append(best_program.deepcopy()) + final_candidate_scores.append(avg_score) + + prog_cache = {i: out for i, out in enumerate(full_outputs)} + validated_program_outputs[best_program.simba_idx] = prog_cache + + # STEP 9: Register all new candidate systems in our global pool + for idx_cand, cand_sys in enumerate(system_candidates): + start = idx_cand * self.bsize + end = (idx_cand + 1) * self.bsize + sys_scores = [outputs[i]["score"] for i in range(start, end)] + register_new_program(cand_sys, sys_scores, batch_idx) + + # Save for hybrid bucket building next round + last_batch_outputs = outputs.copy() + + log_token_usage(trial_logs, batch_idx+1, {"lm": dspy.settings.lm}) + + + best_idx = np.argmax(final_candidate_scores) if final_candidate_scores else 0 + # best_idx = scores.index(max(final_candidate_scores)) if final_candidate_scores else 0 + best_program = final_candidate_programs[best_idx] + logger.info( + f"Final trainset scores: {final_candidate_scores}, Best: {max(final_candidate_scores) if final_candidate_scores else 'N/A'} " + f"(at index {best_idx if final_candidate_scores else 'N/A'})\n\n\n" + ) + # FIXME: Attach all program candidates in decreasing average score to the best program. + best_program.candidate_programs = final_candidate_programs + best_program.winning_programs = winning_programs + best_program.trial_logs = trial_logs + + return best_program diff --git a/dspy/teleprompt/simba_utils.py b/dspy/teleprompt/simba_utils.py index 0aea7b5d33..32438164bb 100644 --- a/dspy/teleprompt/simba_utils.py +++ b/dspy/teleprompt/simba_utils.py @@ -3,20 +3,40 @@ import inspect import logging import textwrap +import re from dspy.adapters.chat_adapter import enumerate_fields from dspy.signatures import InputField, OutputField -from typing import Callable +from typing import Callable, Optional, Dict, Any logger = logging.getLogger(__name__) +def prepare_models_for_resampling(program: dspy.Module, n: int, teacher_settings: Optional[Dict] = None): + + models = [] + if teacher_settings: + with dspy.settings.context(trace=[], **teacher_settings): + lm = dspy.settings.lm + models.append(lm) -def prepare_models_for_resampling(program: dspy.Module, n: int): lm = program.get_lm() or dspy.settings.lm - temps = [lm.kwargs["temperature"]] + [0.5 + i * (0.5 / n) for i in range(n)] - temps = list(dict.fromkeys(temps))[:n] - return [lm.copy(temperature=t) for t in temps] + # Check to see if our model is a reasoning model, which means temp must stay as 1.0 + model_family = lm.model.split("/")[-1].lower() if "/" in lm.model else lm.model.lower() + model_pattern = re.match(r"^o([13])(?:-mini)?", model_family) + + if model_pattern: # Vary the seed + start_seed = 0 if "seed" not in lm.kwargs else lm.kwargs["seed"] + seeds = [start_seed + 1 + i for i in range(n-len(models))] + seeds = list(dict.fromkeys(seeds))[:(n-len(models))] + models.extend([lm.copy(seed=seed) for seed in seeds]) + else: # Vary the temperature + start_temp = 0 if "temperature" not in lm.kwargs else lm.kwargs["temperature"] + temps = [start_temp + 0.5 + i * (0.5 / n) for i in range(n-len(models))] + temps = 
list(dict.fromkeys(temps))[:(n-len(models))] + models.extend([lm.copy(temperature=t) for t in temps]) + + return models def wrap_program(program: dspy.Module, metric: Callable): def wrapped_program(example): @@ -25,33 +45,56 @@ def wrapped_program(example): try: prediction = program(**example.inputs()) except Exception as e: - print(e) + logger.info(e) trace = dspy.settings.trace.copy() + output = None + score = 0.0 + output_metadata = {} + try: - score = metric(example, prediction) + output = metric(example, prediction) + if isinstance(output, (int, float)): + score = output + elif isinstance(output, dspy.Prediction): + if not hasattr(output, 'score'): + raise ValueError("dspy.Prediction must contain a 'score' attribute") + score = output.score + # Just extract fields from _store, excluding 'score' + output_metadata = { + k: v for k, v in output._store.items() if k != "score" + } except Exception as e: - print(e) + logger.info(e) - # Include the `example` in the output for subsequent usage in buckets/strategies. return { "prediction": prediction, "trace": trace, "score": score, - "example": example + "example": example, + "output_metadata": output_metadata } return wrapped_program - - def append_a_demo(demo_input_field_maxlen): def append_a_demo_(bucket, system, **kwargs): predictor2name, name2predictor = kwargs["predictor2name"], kwargs["name2predictor"] + batch_10p_score = kwargs["batch_10p_score"] - trace = bucket[0]["trace"] + logger.info(f"Appending a demo with max length {demo_input_field_maxlen}") + + good = bucket[0] + trace = good["trace"] name2demo = {} + # if good["score"] < batch_10p_score: + # logger.info(f"Skipping appending a demo as good score {good['score']} is below the 10th percentile.") + # return False + if good["score"] <= batch_10p_score: + logger.info(f"Skipping appending a demo as good score {good['score']} is at or below the 10th percentile.") + return False + for step in trace: predictor, _inputs, _outputs = step @@ -62,28 +105,29 @@ def append_a_demo_(bucket, system, **kwargs): demo = dspy.Example(augmented=True, **_inputs, **_outputs) name = predictor2name[id(predictor)] name2demo[name] = demo # keep the last demo for each predictor - for name, demo in name2demo.items(): predictor = name2predictor[name] predictor.demos.append(demo) - logger.info(f"Added {len(name2demo)} demos (one each) across all predictors.") + logger.info(f"Added {len(name2demo)} demos (one each) across all predictors. 
Each predictor now has {len(predictor.demos)} demos total.") return True return append_a_demo_ def append_a_rule(bucket, system, **kwargs): + # Read in kwargs predictor2name = kwargs["predictor2name"] batch_10p_score, batch_90p_score = kwargs["batch_10p_score"], kwargs["batch_90p_score"] + prompt_model = kwargs["prompt_model"] or dspy.settings.lm module_names = [name for name, _ in system.named_predictors()] good, bad = bucket[0], bucket[-1] example = good["example"] - if good["score"] < batch_10p_score or bad["score"] > batch_90p_score: - logger.info(f"Skipping rule generation as good score {good['score']} is below the 10th percentile " - f"*or* bad score {bad['score']} is above the 90th percentile.") + if good["score"] <= batch_10p_score or bad["score"] >= batch_90p_score: + logger.info(f"Skipping rule generation as good score {good['score']} is at or below the 10th percentile " + f"*or* bad score {bad['score']} is at or above the 90th percentile.") return False if good["score"] <= bad["score"]: @@ -116,12 +160,17 @@ def append_a_rule(bucket, system, **kwargs): worse_program_outputs=dict(bad["prediction"] or {}), worse_reward_value=bad["score"], better_reward_value=good["score"], + worse_reward_info=bad["output_metadata"], + better_reward_info=good["output_metadata"], module_names=module_names, ) kwargs = {k: v if isinstance(v, str) else ujson.dumps(recursive_mask(v), indent=2) for k, v in kwargs.items()} - advice = dspy.Predict(OfferFeedback)(**kwargs).module_advice + + with dspy.settings.context(trace=[], lm=prompt_model): + advice_program = dspy.Predict(OfferFeedback) + advice = advice_program(**kwargs).module_advice for name, predictor in system.named_predictors(): if name in advice: @@ -155,11 +204,13 @@ class OfferFeedback(dspy.Signature): ) worse_program_outputs: str = InputField(desc="The outputs of the program that we are analyzing") worse_reward_value: float = InputField(desc="The reward value assigned to the program's outputs") + worse_reward_info: str = InputField(desc="Additional information that might be helpful to understanding the assigned reward value.") better_program_trajectory: str = InputField( desc="The trajectory of the program's execution, showing each module's I/O" ) better_program_outputs: str = InputField(desc="The outputs of the program that we are analyzing") better_reward_value: float = InputField(desc="The reward value assigned to the program's outputs") + better_reward_info: str = InputField(desc="Additional information that might be helpful to understanding the assigned reward value.") module_names: list[str] = InputField(desc="The names of the modules in the program, for which we seek advice") discussion: str = OutputField(desc="Discussing blame of where each module went wrong, if it did") module_advice: dict[str, str] = OutputField( @@ -169,7 +220,6 @@ class OfferFeedback(dspy.Signature): "like the successful trajectory rather than the lower-scoring trajectory." ) - def inspect_modules(program): separator = "-" * 80 output = [separator] @@ -209,4 +259,4 @@ def recursive_mask(o): return tuple(recursive_mask(v) for v in o) # Otherwise, replace it with a placeholder string (or use repr(o)). 
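+    # recursive_mask feeds ujson.dumps in append_a_rule, so any non-serializable value must be replaced with a string here.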
else: - return f"<non-serializable: {type(o).__name__}>" + return f"<non-serializable: {type(o).__name__}>" \ No newline at end of file diff --git a/dspy/teleprompt/utils.py b/dspy/teleprompt/utils.py index 954ce5604a..e0e10cc127 100644 --- a/dspy/teleprompt/utils.py +++ b/dspy/teleprompt/utils.py @@ -1,10 +1,11 @@ import inspect import logging import math import os import random import shutil import sys +from typing import Tuple import numpy as np try: @@ -133,6 +134,60 @@ def get_program_with_highest_avg_score(param_score_dict, fully_evaled_param_comb # If no valid program is found, we return the last valid one that we found return program, mean, key, params +def get_token_usage(model) -> Tuple[int, int]: + """ + Extract total input tokens and output tokens from a model's interaction history. + Returns (total_input_tokens, total_output_tokens). + """ + if not hasattr(model, "history"): + return 0, 0 + + input_tokens = [] + output_tokens = [] + for interaction in model.history: + usage = interaction.get("usage", {}) + _input_tokens = usage.get("prompt_tokens", 0) + _output_tokens = usage.get("completion_tokens", 0) + input_tokens.append(_input_tokens) + output_tokens.append(_output_tokens) + + total_input_tokens = np.sum(input_tokens) + total_output_tokens = np.sum(output_tokens) + + return total_input_tokens, total_output_tokens + +def extract_token_usage(model): + """Return (total_input_tokens, total_output_tokens) by summing usage in model.history.""" + if not model or not hasattr(model, "history"): + # If model is None or doesn't have a .history, return 0 usage. + return 0, 0 + + input_tokens = [] + output_tokens = [] + for interaction in model.history: + usage = interaction.get("usage", {}) + _input_tokens = usage.get("prompt_tokens", 0) + _output_tokens = usage.get("completion_tokens", 0) + input_tokens.append(_input_tokens) + output_tokens.append(_output_tokens) + return int(np.sum(input_tokens)), int(np.sum(output_tokens)) + +def log_token_usage(trial_logs, trial_num, model_dict): + """ + Extract total input and output tokens used by each model and log to trial_logs[trial_num]["token_usage"]. + """ + + token_usage_dict = {} + + for model_name, model in model_dict.items(): + in_tokens, out_tokens = extract_token_usage(model) + token_usage_dict[model_name] = { + "total_input_tokens": in_tokens, + "total_output_tokens": out_tokens + } + + # Store token usage info in trial logs + trial_logs[trial_num]["token_usage"] = token_usage_dict def calculate_last_n_proposed_quality( base_program, trial_logs, evaluate, trainset, devset, n, diff --git a/dspy/utils/parallelizer.py b/dspy/utils/parallelizer.py index 404aab1dce..1607db0477 100644 --- a/dspy/utils/parallelizer.py +++ b/dspy/utils/parallelizer.py @@ -209,4 +209,4 @@ def _update_progress(self, pbar, nresults, ntotal): pbar.set_description(f"Average Metric: {nresults:.2f} / {ntotal} ({pct}%)") else: pbar.set_description(f"Processed {nresults} / {ntotal} examples") - pbar.update() + pbar.update() \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 4fb32fd719..09c9ba472c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,7 @@ classifiers = [ dependencies = [ "backoff>=2.2", "joblib~=1.3", - "openai>=0.28.1,<=1.61.0", + "openai>=0.28.1", "pandas>=2.1.1", "regex>=2023.10.3", "ujson>=5.8.0", @@ -103,7 +103,7 @@ python = ">=3.9,<3.13" pydantic = "^2.0" backoff = "^2.2" joblib = "^1.3" -openai = ">=0.28.1,<=1.61.0" +openai = ">=0.28.1" pandas = "^2.1.1" regex = "^2023.10.3" ujson = "^5.8.0"