[Bug Report] Prioritize Local hf_model.config for Qwen Models to Avoid Unnecessary Hugging Face API Calls #846

Open
@yhr-code

Description

If you are submitting a bug report, please fill in the following details and use the tag [bug].

Describe the bug
When using Qwen-series models (and potentially other models) with transformer_lens, the framework still attempts to fetch hf_config from the Hugging Face Hub even when the model has already been loaded locally and passed in via hf_model. This causes failures in environments where internet access is restricted or unavailable, despite the model files being fully available locally.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformer_lens import HookedTransformer

device = "cuda" if torch.cuda.is_available() else "cpu"

# Tokenizer and model load fine from the local directory alone.
tokenizer_test = AutoTokenizer.from_pretrained("/home/lijiaming/workspace/_store/models/Qwen2.5-0.5B-Instruct", local_files_only=True)
hf_model_test = AutoModelForCausalLM.from_pretrained("/home/lijiaming/workspace/_store/models/Qwen2.5-0.5B-Instruct", local_files_only=True).to(device)

# This call still tries to fetch config.json from huggingface.co.
hooked_model = HookedTransformer.from_pretrained_no_processing(
    model_name="Qwen/Qwen2.5-0.5B-Instruct",
    tokenizer=tokenizer_test,
    hf_model=hf_model_test,
    device=device,
    dtype=torch.bfloat16,
)


---------------------------------------------------------------------------
TimeoutError                              Traceback (most recent call last)
File ~/miniconda3/envs/SAELens/lib/python3.11/site-packages/urllib3/connection.py:198, in HTTPConnection._new_conn(self)
    197 try:
--> 198     sock = connection.create_connection(
    199         (self._dns_host, self.port),
    200         self.timeout,
    201         source_address=self.source_address,
    202         socket_options=self.socket_options,
    203     )
    204 except socket.gaierror as e:

File ~/miniconda3/envs/SAELens/lib/python3.11/site-packages/urllib3/util/connection.py:85, in create_connection(address, timeout, source_address, socket_options)
     84 try:
---> 85     raise err
     86 finally:
     87     # Break explicitly a reference cycle

File ~/miniconda3/envs/SAELens/lib/python3.11/site-packages/urllib3/util/connection.py:73, in create_connection(address, timeout, source_address, socket_options)
     72     sock.bind(source_address)
---> 73 sock.connect(sa)
     74 # Break explicitly a reference cycle

TimeoutError: timed out

The above exception was the direct cause of the following exception:

ConnectTimeoutError                       Traceback (most recent call last)
File ~/miniconda3/envs/SAELens/lib/python3.11/site-packages/urllib3/connectionpool.py:787, in HTTPConnectionPool.urlopen(self, method, url, body, headers, retries, redirect, assert_same_host, timeout, pool_timeout, release_conn, chunked, body_pos, preload_content, decode_content, **response_kw)
    786 # Make the request on the HTTPConnection object
--> 787 response = self._make_request(
    788     conn,
    789     method,
    790     url,
    791     timeout=timeout_obj,
    792     body=body,
    793     headers=headers,
    794     chunked=chunked,
    795     retries=retries,
    796     response_conn=response_conn,
    797     preload_content=preload_content,
    798     decode_content=decode_content,
    799     **response_kw,
    800 )
    802 # Everything went great!

File ~/miniconda3/envs/SAELens/lib/python3.11/site-packages/urllib3/connectionpool.py:488, in HTTPConnectionPool._make_request(self, conn, method, url, body, headers, retries, timeout, chunked, response_conn, preload_content, decode_content, enforce_content_length)
    487         new_e = _wrap_proxy_error(new_e, conn.proxy.scheme)
--> 488     raise new_e
    490 # conn.request() calls http.client.*.request, not the method in
    491 # urllib3.request. It also calls makefile (recv) on the socket.

File ~/miniconda3/envs/SAELens/lib/python3.11/site-packages/urllib3/connectionpool.py:464, in HTTPConnectionPool._make_request(self, conn, method, url, body, headers, retries, timeout, chunked, response_conn, preload_content, decode_content, enforce_content_length)
    463 try:
--> 464     self._validate_conn(conn)
    465 except (SocketTimeout, BaseSSLError) as e:

File ~/miniconda3/envs/SAELens/lib/python3.11/site-packages/urllib3/connectionpool.py:1093, in HTTPSConnectionPool._validate_conn(self, conn)
   1092 if conn.is_closed:
-> 1093     conn.connect()
   1095 # TODO revise this, see https://github.com/urllib3/urllib3/issues/2791

File ~/miniconda3/envs/SAELens/lib/python3.11/site-packages/urllib3/connection.py:704, in HTTPSConnection.connect(self)
    703 sock: socket.socket | ssl.SSLSocket
--> 704 self.sock = sock = self._new_conn()
    705 server_hostname: str = self.host

File ~/miniconda3/envs/SAELens/lib/python3.11/site-packages/urllib3/connection.py:207, in HTTPConnection._new_conn(self)
    206 except SocketTimeout as e:
--> 207     raise ConnectTimeoutError(
    208         self,
    209         f"Connection to {self.host} timed out. (connect timeout={self.timeout})",
    210     ) from e
    212 except OSError as e:

ConnectTimeoutError: (<urllib3.connection.HTTPSConnection object at 0x7ded1df1ecd0>, 'Connection to huggingface.co timed out. (connect timeout=10)')

The above exception was the direct cause of the following exception:

MaxRetryError                             Traceback (most recent call last)
File ~/miniconda3/envs/SAELens/lib/python3.11/site-packages/requests/adapters.py:667, in HTTPAdapter.send(self, request, stream, timeout, verify, cert, proxies)
    666 try:
--> 667     resp = conn.urlopen(
    668         method=request.method,
    669         url=url,
    670         body=request.body,
    671         headers=request.headers,
    672         redirect=False,
    673         assert_same_host=False,
    674         preload_content=False,
    675         decode_content=False,
    676         retries=self.max_retries,
    677         timeout=timeout,
    678         chunked=chunked,
    679     )
    681 except (ProtocolError, OSError) as err:

File ~/miniconda3/envs/SAELens/lib/python3.11/site-packages/urllib3/connectionpool.py:841, in HTTPConnectionPool.urlopen(self, method, url, body, headers, retries, redirect, assert_same_host, timeout, pool_timeout, release_conn, chunked, body_pos, preload_content, decode_content, **response_kw)
    839     new_e = ProtocolError("Connection aborted.", new_e)
--> 841 retries = retries.increment(
    842     method, url, error=new_e, _pool=self, _stacktrace=sys.exc_info()[2]
    843 )
    844 retries.sleep()

File ~/miniconda3/envs/SAELens/lib/python3.11/site-packages/urllib3/util/retry.py:519, in Retry.increment(self, method, url, response, error, _pool, _stacktrace)
    518     reason = error or ResponseError(cause)
--> 519     raise MaxRetryError(_pool, url, reason) from reason  # type: ignore[arg-type]
    521 log.debug("Incremented Retry for (url='%s'): %r", url, new_retry)

MaxRetryError: HTTPSConnectionPool(host='huggingface.co', port=443): Max retries exceeded with url: /Qwen/Qwen2.5-0.5B-Instruct/resolve/main/config.json (Caused by ConnectTimeoutError(<urllib3.connection.HTTPSConnection object at 0x7ded1df1ecd0>, 'Connection to huggingface.co timed out. (connect timeout=10)'))

During handling of the above exception, another exception occurred:

ConnectTimeout                            Traceback (most recent call last)
File ~/miniconda3/envs/SAELens/lib/python3.11/site-packages/huggingface_hub/file_download.py:1374, in _get_metadata_or_catch_error(repo_id, filename, repo_type, revision, endpoint, proxies, etag_timeout, headers, token, local_files_only, relative_filename, storage_folder)
   1373 try:
-> 1374     metadata = get_hf_file_metadata(
   1375         url=url, proxies=proxies, timeout=etag_timeout, headers=headers, token=token
   1376     )
   1377 except EntryNotFoundError as http_error:

File ~/miniconda3/envs/SAELens/lib/python3.11/site-packages/huggingface_hub/utils/_validators.py:114, in validate_hf_hub_args.<locals>._inner_fn(*args, **kwargs)
    112     kwargs = smoothly_deprecate_use_auth_token(fn_name=fn.__name__, has_token=has_token, kwargs=kwargs)
--> 114 return fn(*args, **kwargs)

File ~/miniconda3/envs/SAELens/lib/python3.11/site-packages/huggingface_hub/file_download.py:1294, in get_hf_file_metadata(url, token, proxies, timeout, library_name, library_version, user_agent, headers)
   1293 # Retrieve metadata
-> 1294 r = _request_wrapper(
   1295     method="HEAD",
   1296     url=url,
   1297     headers=hf_headers,
   1298     allow_redirects=False,
   1299     follow_relative_redirects=True,
   1300     proxies=proxies,
   1301     timeout=timeout,
   1302 )
   1303 hf_raise_for_status(r)

File ~/miniconda3/envs/SAELens/lib/python3.11/site-packages/huggingface_hub/file_download.py:278, in _request_wrapper(method, url, follow_relative_redirects, **params)
    277 if follow_relative_redirects:
--> 278     response = _request_wrapper(
    279         method=method,
    280         url=url,
    281         follow_relative_redirects=False,
    282         **params,
    283     )
    285     # If redirection, we redirect only relative paths.
    286     # This is useful in case of a renamed repository.

File ~/miniconda3/envs/SAELens/lib/python3.11/site-packages/huggingface_hub/file_download.py:301, in _request_wrapper(method, url, follow_relative_redirects, **params)
    300 # Perform request and return if status_code is not in the retry list.
--> 301 response = get_session().request(method=method, url=url, **params)
    302 hf_raise_for_status(response)

File ~/miniconda3/envs/SAELens/lib/python3.11/site-packages/requests/sessions.py:589, in Session.request(self, method, url, params, data, headers, cookies, files, auth, timeout, allow_redirects, proxies, hooks, stream, verify, cert, json)
    588 send_kwargs.update(settings)
--> 589 resp = self.send(prep, **send_kwargs)
    591 return resp

File ~/miniconda3/envs/SAELens/lib/python3.11/site-packages/requests/sessions.py:703, in Session.send(self, request, **kwargs)
    702 # Send the request
--> 703 r = adapter.send(request, **kwargs)
    705 # Total elapsed time of the request (approximately)

File ~/miniconda3/envs/SAELens/lib/python3.11/site-packages/huggingface_hub/utils/_http.py:93, in UniqueRequestIdAdapter.send(self, request, *args, **kwargs)
     92 try:
---> 93     return super().send(request, *args, **kwargs)
     94 except requests.RequestException as e:

File ~/miniconda3/envs/SAELens/lib/python3.11/site-packages/requests/adapters.py:688, in HTTPAdapter.send(self, request, stream, timeout, verify, cert, proxies)
    687     if not isinstance(e.reason, NewConnectionError):
--> 688         raise ConnectTimeout(e, request=request)
    690 if isinstance(e.reason, ResponseError):

ConnectTimeout: (MaxRetryError("HTTPSConnectionPool(host='huggingface.co', port=443): Max retries exceeded with url: /Qwen/Qwen2.5-0.5B-Instruct/resolve/main/config.json (Caused by ConnectTimeoutError(<urllib3.connection.HTTPSConnection object at 0x7ded1df1ecd0>, 'Connection to huggingface.co timed out. (connect timeout=10)'))"), '(Request ID: 4b843dda-131e-44bd-936c-cc92ab8e626e)')

The above exception was the direct cause of the following exception:

LocalEntryNotFoundError                   Traceback (most recent call last)
File ~/miniconda3/envs/SAELens/lib/python3.11/site-packages/transformers/utils/hub.py:403, in cached_file(path_or_repo_id, filename, cache_dir, force_download, resume_download, proxies, token, revision, local_files_only, subfolder, repo_type, user_agent, _raise_exceptions_for_gated_repo, _raise_exceptions_for_missing_entries, _raise_exceptions_for_connection_errors, _commit_hash, **deprecated_kwargs)
    401 try:
    402     # Load from URL or cache if already cached
--> 403     resolved_file = hf_hub_download(
    404         path_or_repo_id,
    405         filename,
    406         subfolder=None if len(subfolder) == 0 else subfolder,
    407         repo_type=repo_type,
    408         revision=revision,
    409         cache_dir=cache_dir,
    410         user_agent=user_agent,
    411         force_download=force_download,
    412         proxies=proxies,
    413         resume_download=resume_download,
    414         token=token,
    415         local_files_only=local_files_only,
    416     )
    417 except GatedRepoError as e:

File ~/miniconda3/envs/SAELens/lib/python3.11/site-packages/huggingface_hub/utils/_validators.py:114, in validate_hf_hub_args.<locals>._inner_fn(*args, **kwargs)
    112     kwargs = smoothly_deprecate_use_auth_token(fn_name=fn.__name__, has_token=has_token, kwargs=kwargs)
--> 114 return fn(*args, **kwargs)

File ~/miniconda3/envs/SAELens/lib/python3.11/site-packages/huggingface_hub/file_download.py:860, in hf_hub_download(repo_id, filename, subfolder, repo_type, revision, library_name, library_version, cache_dir, local_dir, user_agent, force_download, proxies, etag_timeout, token, local_files_only, headers, endpoint, resume_download, force_filename, local_dir_use_symlinks)
    859 else:
--> 860     return _hf_hub_download_to_cache_dir(
    861         # Destination
    862         cache_dir=cache_dir,
    863         # File info
    864         repo_id=repo_id,
    865         filename=filename,
    866         repo_type=repo_type,
    867         revision=revision,
    868         # HTTP info
    869         endpoint=endpoint,
    870         etag_timeout=etag_timeout,
    871         headers=hf_headers,
    872         proxies=proxies,
    873         token=token,
    874         # Additional options
    875         local_files_only=local_files_only,
    876         force_download=force_download,
    877     )

File ~/miniconda3/envs/SAELens/lib/python3.11/site-packages/huggingface_hub/file_download.py:967, in _hf_hub_download_to_cache_dir(cache_dir, repo_id, filename, repo_type, revision, endpoint, etag_timeout, headers, proxies, token, local_files_only, force_download)
    966     # Otherwise, raise appropriate error
--> 967     _raise_on_head_call_error(head_call_error, force_download, local_files_only)
    969 # From now on, etag, commit_hash, url and size are not None.

File ~/miniconda3/envs/SAELens/lib/python3.11/site-packages/huggingface_hub/file_download.py:1485, in _raise_on_head_call_error(head_call_error, force_download, local_files_only)
   1483 else:
   1484     # Otherwise: most likely a connection issue or Hub downtime => let's warn the user
-> 1485     raise LocalEntryNotFoundError(
   1486         "An error happened while trying to locate the file on the Hub and we cannot find the requested files"
   1487         " in the local cache. Please check your connection and try again or make sure your Internet connection"
   1488         " is on."
   1489     ) from head_call_error

LocalEntryNotFoundError: An error happened while trying to locate the file on the Hub and we cannot find the requested files in the local cache. Please check your connection and try again or make sure your Internet connection is on.

The above exception was the direct cause of the following exception:

OSError                                   Traceback (most recent call last)
Cell In[3], line 3
      1 tokenizer_test = AutoTokenizer.from_pretrained("/home/lijiaming/workspace/_store/models/Qwen2.5-0.5B-Instruct",local_files_only=True)
      2 hf_model_test = AutoModelForCausalLM.from_pretrained("/home/lijiaming/workspace/_store/models/Qwen2.5-0.5B-Instruct",local_files_only=True).to(device)
----> 3 hooked_model = HookedTransformer.from_pretrained_no_processing(
      4         model_name="Qwen/Qwen2.5-0.5B-Instruct",
      5         tokenizer=tokenizer_test,
      6         hf_model=hf_model_test,
      7         device=device,
      8         dtype=torch.bfloat16,
      9     )

File ~/miniconda3/envs/SAELens/lib/python3.11/site-packages/transformer_lens/HookedTransformer.py:1374, in HookedTransformer.from_pretrained_no_processing(cls, model_name, fold_ln, center_writing_weights, center_unembed, refactor_factored_attn_matrices, fold_value_biases, dtype, default_prepend_bos, default_padding_side, **from_pretrained_kwargs)
   1355 @classmethod
   1356 def from_pretrained_no_processing(
   1357     cls,
   (...)
   1367     **from_pretrained_kwargs,
   1368 ):
   1369     """Wrapper for from_pretrained.
   1370 
   1371     Wrapper for from_pretrained with all boolean flags related to simplifying the model set to
   1372     False. Refer to from_pretrained for details.
   1373     """
-> 1374     return cls.from_pretrained(
   1375         model_name,
   1376         fold_ln=fold_ln,
   1377         center_writing_weights=center_writing_weights,
   1378         center_unembed=center_unembed,
   1379         fold_value_biases=fold_value_biases,
   1380         refactor_factored_attn_matrices=refactor_factored_attn_matrices,
   1381         dtype=dtype,
   1382         default_prepend_bos=default_prepend_bos,
   1383         default_padding_side=default_padding_side,
   1384         **from_pretrained_kwargs,
   1385     )

File ~/miniconda3/envs/SAELens/lib/python3.11/site-packages/transformer_lens/HookedTransformer.py:1282, in HookedTransformer.from_pretrained(cls, model_name, fold_ln, center_writing_weights, center_unembed, refactor_factored_attn_matrices, checkpoint_index, checkpoint_value, hf_model, device, n_devices, tokenizer, move_to_device, fold_value_biases, default_prepend_bos, default_padding_side, dtype, first_n_layers, **from_pretrained_kwargs)
   1278 # Load the config into an HookedTransformerConfig object. If loading from a
   1279 # checkpoint, the config object will contain the information about the
   1280 # checkpoint
   1281 print("hf_cfg",hf_cfg)
-> 1282 cfg = loading.get_pretrained_model_config(
   1283     official_model_name,
   1284     hf_cfg=hf_cfg,
   1285     checkpoint_index=checkpoint_index,
   1286     checkpoint_value=checkpoint_value,
   1287     fold_ln=fold_ln,
   1288     device=device,
   1289     n_devices=n_devices,
   1290     default_prepend_bos=default_prepend_bos,
   1291     dtype=dtype,
   1292     first_n_layers=first_n_layers,
   1293     **from_pretrained_kwargs,
   1294 )
   1296 if cfg.positional_embedding_type == "shortformer":
   1297     if fold_ln:

File ~/miniconda3/envs/SAELens/lib/python3.11/site-packages/transformer_lens/loading_from_pretrained.py:1655, in get_pretrained_model_config(model_name, hf_cfg, checkpoint_index, checkpoint_value, fold_ln, device, n_devices, default_prepend_bos, dtype, first_n_layers, **kwargs)
   1653     print("convert_hf_model_config")
   1654     print(hf_cfg)
-> 1655     cfg_dict = convert_hf_model_config(official_model_name,hf_cfg,**kwargs)
   1656 # Processing common to both model types
   1657 # Remove any prefix, saying the organization who made a model.
   1658 cfg_dict["model_name"] = official_model_name.split("/")[-1]

File ~/miniconda3/envs/SAELens/lib/python3.11/site-packages/transformer_lens/loading_from_pretrained.py:766, in convert_hf_model_config(model_name, hf_cfg, **kwargs)
    756     #add by yhr :We can use official_model_name or the local path
    757     # if hf_cfg != None:
    758     #     print("use hf_cfg")
   (...)
    763     #     )
    764     # else:
    765     print("use official_model_name hf_cfg")
--> 766     hf_config = AutoConfig.from_pretrained(
    767             official_model_name,
    768             token=huggingface_token,
    769             **kwargs,
    770         )
    772     architecture = hf_config.architectures[0]
    774 if official_model_name.startswith(
    775     ("llama-7b", "meta-llama/Llama-2-7b")
    776 ):  # same architecture for LLaMA and Llama-2

File ~/miniconda3/envs/SAELens/lib/python3.11/site-packages/transformers/models/auto/configuration_auto.py:1054, in AutoConfig.from_pretrained(cls, pretrained_model_name_or_path, **kwargs)
   1051 trust_remote_code = kwargs.pop("trust_remote_code", None)
   1052 code_revision = kwargs.pop("code_revision", None)
-> 1054 config_dict, unused_kwargs = PretrainedConfig.get_config_dict(pretrained_model_name_or_path, **kwargs)
   1055 has_remote_code = "auto_map" in config_dict and "AutoConfig" in config_dict["auto_map"]
   1056 has_local_code = "model_type" in config_dict and config_dict["model_type"] in CONFIG_MAPPING

File ~/miniconda3/envs/SAELens/lib/python3.11/site-packages/transformers/configuration_utils.py:591, in PretrainedConfig.get_config_dict(cls, pretrained_model_name_or_path, **kwargs)
    589 original_kwargs = copy.deepcopy(kwargs)
    590 # Get config dict associated with the base config file
--> 591 config_dict, kwargs = cls._get_config_dict(pretrained_model_name_or_path, **kwargs)
    592 if config_dict is None:
    593     return {}, kwargs

File ~/miniconda3/envs/SAELens/lib/python3.11/site-packages/transformers/configuration_utils.py:650, in PretrainedConfig._get_config_dict(cls, pretrained_model_name_or_path, **kwargs)
    646 configuration_file = kwargs.pop("_configuration_file", CONFIG_NAME) if gguf_file is None else gguf_file
    648 try:
    649     # Load from local folder or from cache or download from model Hub and cache
--> 650     resolved_config_file = cached_file(
    651         pretrained_model_name_or_path,
    652         configuration_file,
    653         cache_dir=cache_dir,
    654         force_download=force_download,
    655         proxies=proxies,
    656         resume_download=resume_download,
    657         local_files_only=local_files_only,
    658         token=token,
    659         user_agent=user_agent,
    660         revision=revision,
    661         subfolder=subfolder,
    662         _commit_hash=commit_hash,
    663     )
    664     if resolved_config_file is None:
    665         return None, kwargs

File ~/miniconda3/envs/SAELens/lib/python3.11/site-packages/transformers/utils/hub.py:446, in cached_file(path_or_repo_id, filename, cache_dir, force_download, resume_download, proxies, token, revision, local_files_only, subfolder, repo_type, user_agent, _raise_exceptions_for_gated_repo, _raise_exceptions_for_missing_entries, _raise_exceptions_for_connection_errors, _commit_hash, **deprecated_kwargs)
    440     if (
    441         resolved_file is not None
    442         or not _raise_exceptions_for_missing_entries
    443         or not _raise_exceptions_for_connection_errors
    444     ):
    445         return resolved_file
--> 446     raise EnvironmentError(
    447         f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this file, couldn't find it in the"
    448         f" cached files and it looks like {path_or_repo_id} is not the path to a directory containing a file named"
    449         f" {full_filename}.\nCheckout your internet connection or see how to run the library in offline mode at"
    450         " 'https://huggingface.co/docs/transformers/installation#offline-mode'."
    451     ) from e
    452 except EntryNotFoundError as e:
    453     if not _raise_exceptions_for_missing_entries:

OSError: We couldn't connect to 'https://huggingface.co' to load this file, couldn't find it in the cached files and it looks like Qwen/Qwen2.5-0.5B-Instruct is not the path to a directory containing a file named config.json.
Checkout your internet connection or see how to run the library in offline mode at 'https://huggingface.co/docs/transformers/installation#offline-mode'.
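
Note that the hf_model loaded from the local directory already carries the config that the framework then tries to re-download. A minimal check (using only the objects from the reproduction snippet above; the keys shown are the ones PretrainedConfig.to_dict() normally exposes):

hf_cfg = hf_model_test.config.to_dict()
print(hf_cfg["_name_or_path"])   # the local directory the model was loaded from
print(hf_cfg["architectures"])   # e.g. ['Qwen2ForCausalLM']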

We can observe that the project loads the hf_config solely based on the official_model_name, unless the model is LLaMA, Gemma-2, or Gemma.

The relevant code is in transformer_lens/loading_from_pretrained.py (around line 747):

    # Load HuggingFace model config
    if "llama" in official_model_name.lower():
        architecture = "LlamaForCausalLM"
    elif "gemma-2" in official_model_name.lower():
        architecture = "Gemma2ForCausalLM"
    elif "gemma" in official_model_name.lower():
        architecture = "GemmaForCausalLM"
    else:
        huggingface_token = os.environ.get("HF_TOKEN", None)
        hf_config = AutoConfig.from_pretrained(
            official_model_name,
            token=huggingface_token,
            **kwargs,
        )
        architecture = hf_config.architectures[0]
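
As a stopgap, the offline mode that the error message points to can suppress the Hub call, but only if config.json for the official repo id is already present in the local Hugging Face cache; it does not help when the files live only in a separate local directory like the one used above. A sketch using the standard environment variables, set before transformers is imported:

import os

# Make huggingface_hub / transformers resolve files from the local cache only,
# instead of contacting huggingface.co.
os.environ["HF_HUB_OFFLINE"] = "1"
os.environ["TRANSFORMERS_OFFLINE"] = "1"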

Code example

We can add an additional parameter hf_cfg to convert_hf_model_config in loading_from_pretrained.py and add an if branch that checks whether a config for the already-loaded model is available, so it does not need to query Hugging Face: convert_hf_model_config(model_name, hf_cfg, **kwargs).

def convert_hf_model_config(model_name: str, hf_cfg: Optional[dict] = None, **kwargs):
    """
    Returns the model config for a HuggingFace model, converted to a dictionary
    in the HookedTransformerConfig format.

    Takes the official_model_name (or a local path) and, optionally, the hf_cfg
    dict of an already-loaded HF model as input.
    """
    # In case the user passed in an alias
    if (Path(model_name) / "config.json").exists():
        logging.info("Loading model config from local directory")
        official_model_name = model_name
    else:
        logging.info("Loading model config from get_official_model_name")
        official_model_name = get_official_model_name(model_name)

    # Load HuggingFace model config
    if "llama" in official_model_name.lower():
        architecture = "LlamaForCausalLM"
    elif "gemma-2" in official_model_name.lower():
        architecture = "Gemma2ForCausalLM"
    elif "gemma" in official_model_name.lower():
        architecture = "GemmaForCausalLM"
    else:
        huggingface_token = os.environ.get("HF_TOKEN", None)
        if hf_cfg is not None:
            # The config of the already-loaded model records where it was
            # loaded from, so we can resolve the config without the Hub.
            hf_config = AutoConfig.from_pretrained(
                hf_cfg.get("_name_or_path"),
                token=huggingface_token,
                **kwargs,
            )
        else:
            hf_config = AutoConfig.from_pretrained(
                official_model_name,
                token=huggingface_token,
                **kwargs,
            )

        architecture = hf_config.architectures[0]
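
With this change, a call along the following lines (hypothetical, reusing hf_model_test from the reproduction above) would resolve the config from the local directory recorded in hf_cfg["_name_or_path"] rather than from the Hub:

cfg_dict = convert_hf_model_config(
    "Qwen/Qwen2.5-0.5B-Instruct",
    hf_cfg=hf_model_test.config.to_dict(),  # "_name_or_path" points at the local dir
)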

Then, in get_pretrained_model_config, we pass hf_cfg through to convert_hf_model_config:

def get_pretrained_model_config(
    model_name: str,
    hf_cfg: Optional[dict] = None,
    checkpoint_index: Optional[int] = None,
    checkpoint_value: Optional[int] = None,
    fold_ln: bool = False,
    device: Optional[Union[str, torch.device]] = None,
    n_devices: int = 1,
    default_prepend_bos: Optional[bool] = None,
    dtype: torch.dtype = torch.float32,
    first_n_layers: Optional[int] = None,
    **kwargs,
):
    """Returns the pretrained model config as an HookedTransformerConfig object.

    There are two types of pretrained models: HuggingFace models (where
    AutoModel and AutoConfig work), and models trained by me (NeelNanda) which
    aren't as integrated with HuggingFace infrastructure.

    Args:
        model_name: The name of the model. This can be either the official
            HuggingFace model name, or the name of a model trained by me
            (NeelNanda).
        hf_cfg (dict, optional): Config of a loaded pretrained HF model,
            converted to a dictionary.
        checkpoint_index (int, optional): If loading from a
            checkpoint, the index of the checkpoint to load. Defaults to None.
        checkpoint_value (int, optional): If loading from a checkpoint, the
        value of
            the checkpoint to load, ie the step or token number (each model has
            checkpoints labelled with exactly one of these). Defaults to None.
        fold_ln (bool, optional): Whether to fold the layer norm into the
            subsequent linear layers (see HookedTransformer.fold_layer_norm for
            details). Defaults to False.
        device (str, optional): The device to load the model onto. By
            default will load to CUDA if available, else CPU.
        n_devices (int, optional): The number of devices to split the model across. Defaults to 1.
        default_prepend_bos (bool, optional): Default behavior of whether to prepend the BOS token when the
            methods of HookedTransformer process input text to tokenize (only when input is a string).
            Resolution order for default_prepend_bos:
            1. If user passes value explicitly, use that value
            2. Model-specific default from cfg_dict if it exists (e.g. for bloom models it's False)
            3. Global default (True)

            Even for models not explicitly trained with the BOS token, heads often use the
            first position as a resting position and accordingly lose information from the first token,
            so this empirically seems to give better results. Note that you can also locally override the default behavior
            by passing in prepend_bos=True/False when you call a method that processes the input string.
        dtype (torch.dtype, optional): The dtype to load the TransformerLens model in.
        kwargs: Other optional arguments passed to HuggingFace's from_pretrained.
            Also given to other HuggingFace functions when compatible.

    """
    if Path(model_name).exists():
        # If the model_name is a path, it's a local model; resolve its config locally.
        cfg_dict = convert_hf_model_config(model_name, hf_cfg, **kwargs)
        official_model_name = model_name
    else:
        official_model_name = get_official_model_name(model_name)
    if (
        official_model_name.startswith("NeelNanda")
        or official_model_name.startswith("ArthurConmy")
        or official_model_name.startswith("Baidicoot")
    ):
        cfg_dict = convert_neel_model_config(official_model_name, **kwargs)
    else:
        if official_model_name.startswith(NEED_REMOTE_CODE_MODELS) and not kwargs.get(
            "trust_remote_code", False
        ):
            logging.warning(
                f"Loading model {official_model_name} requires setting trust_remote_code=True"
            )
            kwargs["trust_remote_code"] = True
        
        cfg_dict = convert_hf_model_config(official_model_name, hf_cfg, **kwargs)
    # Processing common to both model types
    # Remove any prefix, saying the organization who made a model.
    cfg_dict["model_name"] = official_model_name.split("/")[-1]
    print("official_model_name.split("'/'")[-1]",official_model_name.split("/")[-1])
    # Don't need to initialize weights, we're loading from pretrained
    cfg_dict["init_weights"] = False

    if (
        "positional_embedding_type" in cfg_dict
        and cfg_dict["positional_embedding_type"] == "shortformer"
        and fold_ln
    ):
        logging.warning(
            "You tried to specify fold_ln=True for a shortformer model, but this can't be done! Setting fold_ln=False instead."
        )
        fold_ln = False

    if device is not None:
        cfg_dict["device"] = device

    cfg_dict["dtype"] = dtype

    if fold_ln:
        if cfg_dict["normalization_type"] in ["LN", "LNPre"]:
            cfg_dict["normalization_type"] = "LNPre"
        elif cfg_dict["normalization_type"] in ["RMS", "RMSPre"]:
            cfg_dict["normalization_type"] = "RMSPre"
        else:
            logging.warning("Cannot fold in layer norm, normalization_type is not LN.")

    if checkpoint_index is not None or checkpoint_value is not None:
        checkpoint_labels, checkpoint_label_type = get_checkpoint_labels(
            official_model_name,
            **kwargs,
        )
        cfg_dict["from_checkpoint"] = True
        cfg_dict["checkpoint_label_type"] = checkpoint_label_type
        if checkpoint_index is not None:
            cfg_dict["checkpoint_index"] = checkpoint_index
            cfg_dict["checkpoint_value"] = checkpoint_labels[checkpoint_index]
        elif checkpoint_value is not None:
            assert (
                checkpoint_value in checkpoint_labels
            ), f"Checkpoint value {checkpoint_value} is not in list of available checkpoints"
            cfg_dict["checkpoint_value"] = checkpoint_value
            cfg_dict["checkpoint_index"] = checkpoint_labels.index(checkpoint_value)
    else:
        cfg_dict["from_checkpoint"] = False

    cfg_dict["device"] = device
    cfg_dict["n_devices"] = n_devices

    if default_prepend_bos is not None:
        # User explicitly set prepend_bos behavior, override config/default value
        cfg_dict["default_prepend_bos"] = default_prepend_bos
    elif "default_prepend_bos" not in cfg_dict:
        # No config value or user override, set default value (True)
        cfg_dict["default_prepend_bos"] = True

    if hf_cfg is not None:
        cfg_dict["load_in_4bit"] = hf_cfg.get("quantization_config", {}).get("load_in_4bit", False)
    if first_n_layers is not None:
        cfg_dict["n_layers"] = first_n_layers

    cfg = HookedTransformerConfig.from_dict(cfg_dict)
    print("cfg",cfg)
    return cfg
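
A quick way to sanity-check the proposed behaviour is to make any remote config lookup fail loudly and confirm that the config still resolves. This is only a sketch (the local_qwen_dir fixture and the test name are assumptions, not part of the repo's test suite):

from pathlib import Path
from unittest import mock

from transformers import AutoConfig
import transformer_lens.loading_from_pretrained as loading

_real_from_pretrained = AutoConfig.from_pretrained

def _local_only(name_or_path, *args, **kwargs):
    # Refuse anything that is not an existing local directory or file.
    assert Path(str(name_or_path)).exists(), f"unexpected Hub lookup: {name_or_path}"
    return _real_from_pretrained(name_or_path, *args, **kwargs)

def test_qwen_config_resolves_without_hub(local_qwen_dir):  # local_qwen_dir: assumed fixture with the local model files
    hf_cfg = {"_name_or_path": str(local_qwen_dir)}
    with mock.patch.object(AutoConfig, "from_pretrained", side_effect=_local_only):
        cfg_dict = loading.convert_hf_model_config("Qwen/Qwen2.5-0.5B-Instruct", hf_cfg)
    assert cfg_dict["n_layers"] > 0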

System Info

Additional context

Checklist

  • I have checked that there is no similar issue in the repo (required)

Metadata

Assignees: No one assigned

Labels: complexity-moderate (Moderately complicated issues for people who have intermediate experience with the code), high-priority (Maintainers are interested in these issues being solved before others)
