Description
If you are submitting a bug report, please fill in the following details and use the tag [bug].
Describe the bug
When using Qwen series models (and potentially other models) with transformer_lens, the framework still attempts to fetch the hf_config from the Hugging Face Hub even when the model has already been loaded from a local path. This fails in environments where internet access is restricted or unavailable, despite the model files being fully available locally.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformer_lens import HookedTransformer

device = "cuda" if torch.cuda.is_available() else "cpu"  # assumed; the original report does not show how `device` was defined

tokenizer_test = AutoTokenizer.from_pretrained("/home/lijiaming/workspace/_store/models/Qwen2.5-0.5B-Instruct", local_files_only=True)
hf_model_test = AutoModelForCausalLM.from_pretrained("/home/lijiaming/workspace/_store/models/Qwen2.5-0.5B-Instruct", local_files_only=True).to(device)

# Even though the tokenizer and model above are fully local, this call still tries to reach huggingface.co:
hooked_model = HookedTransformer.from_pretrained_no_processing(
    model_name="Qwen/Qwen2.5-0.5B-Instruct",
    tokenizer=tokenizer_test,
    hf_model=hf_model_test,
    device=device,
    dtype=torch.bfloat16,
)
---------------------------------------------------------------------------
TimeoutError Traceback (most recent call last)
File ~/miniconda3/envs/SAELens/lib/python3.11/site-packages/urllib3/connection.py:198, in HTTPConnection._new_conn(self)
197 try:
--> 198 sock = connection.create_connection(
199 (self._dns_host, self.port),
200 self.timeout,
201 source_address=self.source_address,
202 socket_options=self.socket_options,
203 )
204 except socket.gaierror as e:
File ~/miniconda3/envs/SAELens/lib/python3.11/site-packages/urllib3/util/connection.py:85, in create_connection(address, timeout, source_address, socket_options)
84 try:
---> 85 raise err
86 finally:
87 # Break explicitly a reference cycle
File ~/miniconda3/envs/SAELens/lib/python3.11/site-packages/urllib3/util/connection.py:73, in create_connection(address, timeout, source_address, socket_options)
72 sock.bind(source_address)
---> 73 sock.connect(sa)
74 # Break explicitly a reference cycle
TimeoutError: timed out
The above exception was the direct cause of the following exception:
ConnectTimeoutError Traceback (most recent call last)
File ~/miniconda3/envs/SAELens/lib/python3.11/site-packages/urllib3/connectionpool.py:787, in HTTPConnectionPool.urlopen(self, method, url, body, headers, retries, redirect, assert_same_host, timeout, pool_timeout, release_conn, chunked, body_pos, preload_content, decode_content, **response_kw)
786 # Make the request on the HTTPConnection object
--> 787 response = self._make_request(
788 conn,
789 method,
790 url,
791 timeout=timeout_obj,
792 body=body,
793 headers=headers,
794 chunked=chunked,
795 retries=retries,
796 response_conn=response_conn,
797 preload_content=preload_content,
798 decode_content=decode_content,
799 **response_kw,
800 )
802 # Everything went great!
File ~/miniconda3/envs/SAELens/lib/python3.11/site-packages/urllib3/connectionpool.py:488, in HTTPConnectionPool._make_request(self, conn, method, url, body, headers, retries, timeout, chunked, response_conn, preload_content, decode_content, enforce_content_length)
487 new_e = _wrap_proxy_error(new_e, conn.proxy.scheme)
--> 488 raise new_e
490 # conn.request() calls http.client.*.request, not the method in
491 # urllib3.request. It also calls makefile (recv) on the socket.
File ~/miniconda3/envs/SAELens/lib/python3.11/site-packages/urllib3/connectionpool.py:464, in HTTPConnectionPool._make_request(self, conn, method, url, body, headers, retries, timeout, chunked, response_conn, preload_content, decode_content, enforce_content_length)
463 try:
--> 464 self._validate_conn(conn)
465 except (SocketTimeout, BaseSSLError) as e:
File ~/miniconda3/envs/SAELens/lib/python3.11/site-packages/urllib3/connectionpool.py:1093, in HTTPSConnectionPool._validate_conn(self, conn)
1092 if conn.is_closed:
-> 1093 conn.connect()
1095 # TODO revise this, see https://github.com/urllib3/urllib3/issues/2791
File ~/miniconda3/envs/SAELens/lib/python3.11/site-packages/urllib3/connection.py:704, in HTTPSConnection.connect(self)
703 sock: socket.socket | ssl.SSLSocket
--> 704 self.sock = sock = self._new_conn()
705 server_hostname: str = self.host
File ~/miniconda3/envs/SAELens/lib/python3.11/site-packages/urllib3/connection.py:207, in HTTPConnection._new_conn(self)
206 except SocketTimeout as e:
--> 207 raise ConnectTimeoutError(
208 self,
209 f"Connection to {self.host} timed out. (connect timeout={self.timeout})",
210 ) from e
212 except OSError as e:
ConnectTimeoutError: (<urllib3.connection.HTTPSConnection object at 0x7ded1df1ecd0>, 'Connection to huggingface.co timed out. (connect timeout=10)')
The above exception was the direct cause of the following exception:
MaxRetryError Traceback (most recent call last)
File ~/miniconda3/envs/SAELens/lib/python3.11/site-packages/requests/adapters.py:667, in HTTPAdapter.send(self, request, stream, timeout, verify, cert, proxies)
666 try:
--> 667 resp = conn.urlopen(
668 method=request.method,
669 url=url,
670 body=request.body,
671 headers=request.headers,
672 redirect=False,
673 assert_same_host=False,
674 preload_content=False,
675 decode_content=False,
676 retries=self.max_retries,
677 timeout=timeout,
678 chunked=chunked,
679 )
681 except (ProtocolError, OSError) as err:
File ~/miniconda3/envs/SAELens/lib/python3.11/site-packages/urllib3/connectionpool.py:841, in HTTPConnectionPool.urlopen(self, method, url, body, headers, retries, redirect, assert_same_host, timeout, pool_timeout, release_conn, chunked, body_pos, preload_content, decode_content, **response_kw)
839 new_e = ProtocolError("Connection aborted.", new_e)
--> 841 retries = retries.increment(
842 method, url, error=new_e, _pool=self, _stacktrace=sys.exc_info()[2]
843 )
844 retries.sleep()
File ~/miniconda3/envs/SAELens/lib/python3.11/site-packages/urllib3/util/retry.py:519, in Retry.increment(self, method, url, response, error, _pool, _stacktrace)
518 reason = error or ResponseError(cause)
--> 519 raise MaxRetryError(_pool, url, reason) from reason # type: ignore[arg-type]
521 log.debug("Incremented Retry for (url='%s'): %r", url, new_retry)
MaxRetryError: HTTPSConnectionPool(host='huggingface.co', port=443): Max retries exceeded with url: /Qwen/Qwen2.5-0.5B-Instruct/resolve/main/config.json (Caused by ConnectTimeoutError(<urllib3.connection.HTTPSConnection object at 0x7ded1df1ecd0>, 'Connection to huggingface.co timed out. (connect timeout=10)'))
During handling of the above exception, another exception occurred:
ConnectTimeout Traceback (most recent call last)
File ~/miniconda3/envs/SAELens/lib/python3.11/site-packages/huggingface_hub/file_download.py:1374, in _get_metadata_or_catch_error(repo_id, filename, repo_type, revision, endpoint, proxies, etag_timeout, headers, token, local_files_only, relative_filename, storage_folder)
1373 try:
-> 1374 metadata = get_hf_file_metadata(
1375 url=url, proxies=proxies, timeout=etag_timeout, headers=headers, token=token
1376 )
1377 except EntryNotFoundError as http_error:
File ~/miniconda3/envs/SAELens/lib/python3.11/site-packages/huggingface_hub/utils/_validators.py:114, in validate_hf_hub_args.<locals>._inner_fn(*args, **kwargs)
112 kwargs = smoothly_deprecate_use_auth_token(fn_name=fn.__name__, has_token=has_token, kwargs=kwargs)
--> 114 return fn(*args, **kwargs)
File ~/miniconda3/envs/SAELens/lib/python3.11/site-packages/huggingface_hub/file_download.py:1294, in get_hf_file_metadata(url, token, proxies, timeout, library_name, library_version, user_agent, headers)
1293 # Retrieve metadata
-> 1294 r = _request_wrapper(
1295 method="HEAD",
1296 url=url,
1297 headers=hf_headers,
1298 allow_redirects=False,
1299 follow_relative_redirects=True,
1300 proxies=proxies,
1301 timeout=timeout,
1302 )
1303 hf_raise_for_status(r)
File ~/miniconda3/envs/SAELens/lib/python3.11/site-packages/huggingface_hub/file_download.py:278, in _request_wrapper(method, url, follow_relative_redirects, **params)
277 if follow_relative_redirects:
--> 278 response = _request_wrapper(
279 method=method,
280 url=url,
281 follow_relative_redirects=False,
282 **params,
283 )
285 # If redirection, we redirect only relative paths.
286 # This is useful in case of a renamed repository.
File ~/miniconda3/envs/SAELens/lib/python3.11/site-packages/huggingface_hub/file_download.py:301, in _request_wrapper(method, url, follow_relative_redirects, **params)
300 # Perform request and return if status_code is not in the retry list.
--> 301 response = get_session().request(method=method, url=url, **params)
302 hf_raise_for_status(response)
File ~/miniconda3/envs/SAELens/lib/python3.11/site-packages/requests/sessions.py:589, in Session.request(self, method, url, params, data, headers, cookies, files, auth, timeout, allow_redirects, proxies, hooks, stream, verify, cert, json)
588 send_kwargs.update(settings)
--> 589 resp = self.send(prep, **send_kwargs)
591 return resp
File ~/miniconda3/envs/SAELens/lib/python3.11/site-packages/requests/sessions.py:703, in Session.send(self, request, **kwargs)
702 # Send the request
--> 703 r = adapter.send(request, **kwargs)
705 # Total elapsed time of the request (approximately)
File ~/miniconda3/envs/SAELens/lib/python3.11/site-packages/huggingface_hub/utils/_http.py:93, in UniqueRequestIdAdapter.send(self, request, *args, **kwargs)
92 try:
---> 93 return super().send(request, *args, **kwargs)
94 except requests.RequestException as e:
File ~/miniconda3/envs/SAELens/lib/python3.11/site-packages/requests/adapters.py:688, in HTTPAdapter.send(self, request, stream, timeout, verify, cert, proxies)
687 if not isinstance(e.reason, NewConnectionError):
--> 688 raise ConnectTimeout(e, request=request)
690 if isinstance(e.reason, ResponseError):
ConnectTimeout: (MaxRetryError("HTTPSConnectionPool(host='huggingface.co', port=443): Max retries exceeded with url: /Qwen/Qwen2.5-0.5B-Instruct/resolve/main/config.json (Caused by ConnectTimeoutError(<urllib3.connection.HTTPSConnection object at 0x7ded1df1ecd0>, 'Connection to huggingface.co timed out. (connect timeout=10)'))"), '(Request ID: 4b843dda-131e-44bd-936c-cc92ab8e626e)')
The above exception was the direct cause of the following exception:
LocalEntryNotFoundError Traceback (most recent call last)
File ~/miniconda3/envs/SAELens/lib/python3.11/site-packages/transformers/utils/hub.py:403, in cached_file(path_or_repo_id, filename, cache_dir, force_download, resume_download, proxies, token, revision, local_files_only, subfolder, repo_type, user_agent, _raise_exceptions_for_gated_repo, _raise_exceptions_for_missing_entries, _raise_exceptions_for_connection_errors, _commit_hash, **deprecated_kwargs)
401 try:
402 # Load from URL or cache if already cached
--> 403 resolved_file = hf_hub_download(
404 path_or_repo_id,
405 filename,
406 subfolder=None if len(subfolder) == 0 else subfolder,
407 repo_type=repo_type,
408 revision=revision,
409 cache_dir=cache_dir,
410 user_agent=user_agent,
411 force_download=force_download,
412 proxies=proxies,
413 resume_download=resume_download,
414 token=token,
415 local_files_only=local_files_only,
416 )
417 except GatedRepoError as e:
File ~/miniconda3/envs/SAELens/lib/python3.11/site-packages/huggingface_hub/utils/_validators.py:114, in validate_hf_hub_args.<locals>._inner_fn(*args, **kwargs)
112 kwargs = smoothly_deprecate_use_auth_token(fn_name=fn.__name__, has_token=has_token, kwargs=kwargs)
--> 114 return fn(*args, **kwargs)
File ~/miniconda3/envs/SAELens/lib/python3.11/site-packages/huggingface_hub/file_download.py:860, in hf_hub_download(repo_id, filename, subfolder, repo_type, revision, library_name, library_version, cache_dir, local_dir, user_agent, force_download, proxies, etag_timeout, token, local_files_only, headers, endpoint, resume_download, force_filename, local_dir_use_symlinks)
859 else:
--> 860 return _hf_hub_download_to_cache_dir(
861 # Destination
862 cache_dir=cache_dir,
863 # File info
864 repo_id=repo_id,
865 filename=filename,
866 repo_type=repo_type,
867 revision=revision,
868 # HTTP info
869 endpoint=endpoint,
870 etag_timeout=etag_timeout,
871 headers=hf_headers,
872 proxies=proxies,
873 token=token,
874 # Additional options
875 local_files_only=local_files_only,
876 force_download=force_download,
877 )
File ~/miniconda3/envs/SAELens/lib/python3.11/site-packages/huggingface_hub/file_download.py:967, in _hf_hub_download_to_cache_dir(cache_dir, repo_id, filename, repo_type, revision, endpoint, etag_timeout, headers, proxies, token, local_files_only, force_download)
966 # Otherwise, raise appropriate error
--> 967 _raise_on_head_call_error(head_call_error, force_download, local_files_only)
969 # From now on, etag, commit_hash, url and size are not None.
File ~/miniconda3/envs/SAELens/lib/python3.11/site-packages/huggingface_hub/file_download.py:1485, in _raise_on_head_call_error(head_call_error, force_download, local_files_only)
1483 else:
1484 # Otherwise: most likely a connection issue or Hub downtime => let's warn the user
-> 1485 raise LocalEntryNotFoundError(
1486 "An error happened while trying to locate the file on the Hub and we cannot find the requested files"
1487 " in the local cache. Please check your connection and try again or make sure your Internet connection"
1488 " is on."
1489 ) from head_call_error
LocalEntryNotFoundError: An error happened while trying to locate the file on the Hub and we cannot find the requested files in the local cache. Please check your connection and try again or make sure your Internet connection is on.
The above exception was the direct cause of the following exception:
OSError Traceback (most recent call last)
Cell In[3], line 3
1 tokenizer_test = AutoTokenizer.from_pretrained("/home/lijiaming/workspace/_store/models/Qwen2.5-0.5B-Instruct",local_files_only=True)
2 hf_model_test = AutoModelForCausalLM.from_pretrained("/home/lijiaming/workspace/_store/models/Qwen2.5-0.5B-Instruct",local_files_only=True).to(device)
----> 3 hooked_model = HookedTransformer.from_pretrained_no_processing(
4 model_name="Qwen/Qwen2.5-0.5B-Instruct",
5 tokenizer=tokenizer_test,
6 hf_model=hf_model_test,
7 device=device,
8 dtype=torch.bfloat16,
9 )
File ~/miniconda3/envs/SAELens/lib/python3.11/site-packages/transformer_lens/HookedTransformer.py:1374, in HookedTransformer.from_pretrained_no_processing(cls, model_name, fold_ln, center_writing_weights, center_unembed, refactor_factored_attn_matrices, fold_value_biases, dtype, default_prepend_bos, default_padding_side, **from_pretrained_kwargs)
1355 @classmethod
1356 def from_pretrained_no_processing(
1357 cls,
(...)
1367 **from_pretrained_kwargs,
1368 ):
1369 """Wrapper for from_pretrained.
1370
1371 Wrapper for from_pretrained with all boolean flags related to simplifying the model set to
1372 False. Refer to from_pretrained for details.
1373 """
-> 1374 return cls.from_pretrained(
1375 model_name,
1376 fold_ln=fold_ln,
1377 center_writing_weights=center_writing_weights,
1378 center_unembed=center_unembed,
1379 fold_value_biases=fold_value_biases,
1380 refactor_factored_attn_matrices=refactor_factored_attn_matrices,
1381 dtype=dtype,
1382 default_prepend_bos=default_prepend_bos,
1383 default_padding_side=default_padding_side,
1384 **from_pretrained_kwargs,
1385 )
File ~/miniconda3/envs/SAELens/lib/python3.11/site-packages/transformer_lens/HookedTransformer.py:1282, in HookedTransformer.from_pretrained(cls, model_name, fold_ln, center_writing_weights, center_unembed, refactor_factored_attn_matrices, checkpoint_index, checkpoint_value, hf_model, device, n_devices, tokenizer, move_to_device, fold_value_biases, default_prepend_bos, default_padding_side, dtype, first_n_layers, **from_pretrained_kwargs)
1278 # Load the config into an HookedTransformerConfig object. If loading from a
1279 # checkpoint, the config object will contain the information about the
1280 # checkpoint
1281 print("hf_cfg",hf_cfg)
-> 1282 cfg = loading.get_pretrained_model_config(
1283 official_model_name,
1284 hf_cfg=hf_cfg,
1285 checkpoint_index=checkpoint_index,
1286 checkpoint_value=checkpoint_value,
1287 fold_ln=fold_ln,
1288 device=device,
1289 n_devices=n_devices,
1290 default_prepend_bos=default_prepend_bos,
1291 dtype=dtype,
1292 first_n_layers=first_n_layers,
1293 **from_pretrained_kwargs,
1294 )
1296 if cfg.positional_embedding_type == "shortformer":
1297 if fold_ln:
File ~/miniconda3/envs/SAELens/lib/python3.11/site-packages/transformer_lens/loading_from_pretrained.py:1655, in get_pretrained_model_config(model_name, hf_cfg, checkpoint_index, checkpoint_value, fold_ln, device, n_devices, default_prepend_bos, dtype, first_n_layers, **kwargs)
1653 print("convert_hf_model_config")
1654 print(hf_cfg)
-> 1655 cfg_dict = convert_hf_model_config(official_model_name,hf_cfg,**kwargs)
1656 # Processing common to both model types
1657 # Remove any prefix, saying the organization who made a model.
1658 cfg_dict["model_name"] = official_model_name.split("/")[-1]
File ~/miniconda3/envs/SAELens/lib/python3.11/site-packages/transformer_lens/loading_from_pretrained.py:766, in convert_hf_model_config(model_name, hf_cfg, **kwargs)
756 #add by yhr :We can use official_model_name or the local path
757 # if hf_cfg != None:
758 # print("use hf_cfg")
(...)
763 # )
764 # else:
765 print("use official_model_name hf_cfg")
--> 766 hf_config = AutoConfig.from_pretrained(
767 official_model_name,
768 token=huggingface_token,
769 **kwargs,
770 )
772 architecture = hf_config.architectures[0]
774 if official_model_name.startswith(
775 ("llama-7b", "meta-llama/Llama-2-7b")
776 ): # same architecture for LLaMA and Llama-2
File ~/miniconda3/envs/SAELens/lib/python3.11/site-packages/transformers/models/auto/configuration_auto.py:1054, in AutoConfig.from_pretrained(cls, pretrained_model_name_or_path, **kwargs)
1051 trust_remote_code = kwargs.pop("trust_remote_code", None)
1052 code_revision = kwargs.pop("code_revision", None)
-> 1054 config_dict, unused_kwargs = PretrainedConfig.get_config_dict(pretrained_model_name_or_path, **kwargs)
1055 has_remote_code = "auto_map" in config_dict and "AutoConfig" in config_dict["auto_map"]
1056 has_local_code = "model_type" in config_dict and config_dict["model_type"] in CONFIG_MAPPING
File ~/miniconda3/envs/SAELens/lib/python3.11/site-packages/transformers/configuration_utils.py:591, in PretrainedConfig.get_config_dict(cls, pretrained_model_name_or_path, **kwargs)
589 original_kwargs = copy.deepcopy(kwargs)
590 # Get config dict associated with the base config file
--> 591 config_dict, kwargs = cls._get_config_dict(pretrained_model_name_or_path, **kwargs)
592 if config_dict is None:
593 return {}, kwargs
File ~/miniconda3/envs/SAELens/lib/python3.11/site-packages/transformers/configuration_utils.py:650, in PretrainedConfig._get_config_dict(cls, pretrained_model_name_or_path, **kwargs)
646 configuration_file = kwargs.pop("_configuration_file", CONFIG_NAME) if gguf_file is None else gguf_file
648 try:
649 # Load from local folder or from cache or download from model Hub and cache
--> 650 resolved_config_file = cached_file(
651 pretrained_model_name_or_path,
652 configuration_file,
653 cache_dir=cache_dir,
654 force_download=force_download,
655 proxies=proxies,
656 resume_download=resume_download,
657 local_files_only=local_files_only,
658 token=token,
659 user_agent=user_agent,
660 revision=revision,
661 subfolder=subfolder,
662 _commit_hash=commit_hash,
663 )
664 if resolved_config_file is None:
665 return None, kwargs
File ~/miniconda3/envs/SAELens/lib/python3.11/site-packages/transformers/utils/hub.py:446, in cached_file(path_or_repo_id, filename, cache_dir, force_download, resume_download, proxies, token, revision, local_files_only, subfolder, repo_type, user_agent, _raise_exceptions_for_gated_repo, _raise_exceptions_for_missing_entries, _raise_exceptions_for_connection_errors, _commit_hash, **deprecated_kwargs)
440 if (
441 resolved_file is not None
442 or not _raise_exceptions_for_missing_entries
443 or not _raise_exceptions_for_connection_errors
444 ):
445 return resolved_file
--> 446 raise EnvironmentError(
447 f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this file, couldn't find it in the"
448 f" cached files and it looks like {path_or_repo_id} is not the path to a directory containing a file named"
449 f" {full_filename}.\nCheckout your internet connection or see how to run the library in offline mode at"
450 " 'https://huggingface.co/docs/transformers/installation#offline-mode'."
451 ) from e
452 except EntryNotFoundError as e:
453 if not _raise_exceptions_for_missing_entries:
OSError: We couldn't connect to 'https://huggingface.co' to load this file, couldn't find it in the cached files and it looks like Qwen/Qwen2.5-0.5B-Instruct is not the path to a directory containing a file named config.json.
Checkout your internet connection or see how to run the library in offline mode at 'https://huggingface.co/docs/transformers/installation#offline-mode'.
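The error message above points at transformers' offline mode. For completeness: forcing offline mode is only a stopgap (my assumption, not a fix for the underlying behaviour), since it merely makes the lookup fail fast or fall back to the HF cache; it does not make transformer_lens use the config of the hf_model that was already passed in.

# Stopgap sketch: both variables are standard transformers / huggingface_hub settings and are
# read at import time, so they must be set before importing either library (or exported in the shell).
# They only help if Qwen/Qwen2.5-0.5B-Instruct's config.json is already in the local HF cache.
import os
os.environ["HF_HUB_OFFLINE"] = "1"
os.environ["TRANSFORMERS_OFFLINE"] = "1"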
We can observe that the project loads the hf_config based solely on the official_model_name, unless the model is LLaMA, Gemma-2, or Gemma. The code is in transformer_lens/loading_from_pretrained.py (line 747):
# Load HuggingFace model config
if "llama" in official_model_name.lower():
    architecture = "LlamaForCausalLM"
elif "gemma-2" in official_model_name.lower():
    architecture = "Gemma2ForCausalLM"
elif "gemma" in official_model_name.lower():
    architecture = "GemmaForCausalLM"
else:
    huggingface_token = os.environ.get("HF_TOKEN", None)
    hf_config = AutoConfig.from_pretrained(
        official_model_name,
        token=huggingface_token,
        **kwargs,
    )
    architecture = hf_config.architectures[0]
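For comparison, AutoConfig.from_pretrained resolves a plain local directory without any network access, which is what the fix proposed below relies on (small sketch using the local path from this report; untested here):

from transformers import AutoConfig

local_dir = "/home/lijiaming/workspace/_store/models/Qwen2.5-0.5B-Instruct"
hf_config = AutoConfig.from_pretrained(local_dir, local_files_only=True)
print(hf_config.architectures[0])  # expected: "Qwen2ForCausalLM", with no request to huggingface.co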
Code example
We can add an additional parameter hf_cfg to the convert_hf_model_config function in loading_from_pretrained.py, and add an if branch that checks whether a config for the already-loaded model is available, so it does not need to query Hugging Face: convert_hf_model_config(model_name, hf_cfg, **kwargs).
def convert_hf_model_config(model_name: str, hf_cfg: Optional[dict] = None, **kwargs):
    """
    Returns the model config for a HuggingFace model, converted to a dictionary
    in the HookedTransformerConfig format.

    Takes the official_model_name as an input.
    """
    # In case the user passed in an alias
    if (Path(model_name) / "config.json").exists():
        logging.info("Loading model config from local directory")
        official_model_name = model_name
    else:
        logging.info("Loading model config from get_official_model_name")
        official_model_name = get_official_model_name(model_name)

    # Load HuggingFace model config
    if "llama" in official_model_name.lower():
        print("use LlamaForCausalLM architecture")
        architecture = "LlamaForCausalLM"
    elif "gemma-2" in official_model_name.lower():
        architecture = "Gemma2ForCausalLM"
    elif "gemma" in official_model_name.lower():
        architecture = "GemmaForCausalLM"
    else:
        huggingface_token = os.environ.get("HF_TOKEN", None)
        if hf_cfg is not None:
            # Resolve the config from the path the hf_model was originally loaded from
            # (a local directory in the offline case) instead of the official repo name.
            hf_config = AutoConfig.from_pretrained(
                hf_cfg.get("_name_or_path"),
                token=huggingface_token,
                **kwargs,
            )
        else:
            hf_config = AutoConfig.from_pretrained(
                official_model_name,
                token=huggingface_token,
                **kwargs,
            )
        architecture = hf_config.architectures[0]
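A further variant (a sketch only, not tested against the codebase) would skip AutoConfig.from_pretrained entirely whenever hf_cfg is available and rebuild the config object from the dict already attached to the passed-in hf_model. CONFIG_MAPPING and PretrainedConfig.from_dict are standard transformers APIs, but the helper below and its wiring into convert_hf_model_config are my assumptions:

from transformers import AutoConfig, CONFIG_MAPPING


def config_from_hf_cfg(hf_cfg, official_model_name, huggingface_token=None, **kwargs):
    # hf_cfg is assumed to be hf_model.config.to_dict(), which always carries "model_type".
    if hf_cfg is not None and "model_type" in hf_cfg:
        return CONFIG_MAPPING[hf_cfg["model_type"]].from_dict(hf_cfg)
    # Fall back to the existing lookup by repo name (online or HF cache).
    return AutoConfig.from_pretrained(official_model_name, token=huggingface_token, **kwargs)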
Then, in get_pretrained_model_config, we pass hf_cfg through to convert_hf_model_config:
def get_pretrained_model_config(
    model_name: str,
    hf_cfg: Optional[dict] = None,
    checkpoint_index: Optional[int] = None,
    checkpoint_value: Optional[int] = None,
    fold_ln: bool = False,
    device: Optional[Union[str, torch.device]] = None,
    n_devices: int = 1,
    default_prepend_bos: Optional[bool] = None,
    dtype: torch.dtype = torch.float32,
    first_n_layers: Optional[int] = None,
    **kwargs,
):
    """Returns the pretrained model config as an HookedTransformerConfig object.

    There are two types of pretrained models: HuggingFace models (where
    AutoModel and AutoConfig work), and models trained by me (NeelNanda) which
    aren't as integrated with HuggingFace infrastructure.

    Args:
        model_name: The name of the model. This can be either the official
            HuggingFace model name, or the name of a model trained by me
            (NeelNanda).
        hf_cfg (dict, optional): Config of a loaded pretrained HF model,
            converted to a dictionary.
        checkpoint_index (int, optional): If loading from a checkpoint, the
            index of the checkpoint to load. Defaults to None.
        checkpoint_value (int, optional): If loading from a checkpoint, the
            value of the checkpoint to load, ie the step or token number (each
            model has checkpoints labelled with exactly one of these).
            Defaults to None.
        fold_ln (bool, optional): Whether to fold the layer norm into the
            subsequent linear layers (see HookedTransformer.fold_layer_norm for
            details). Defaults to False.
        device (str, optional): The device to load the model onto. By
            default will load to CUDA if available, else CPU.
        n_devices (int, optional): The number of devices to split the model across. Defaults to 1.
        default_prepend_bos (bool, optional): Default behavior of whether to prepend the BOS token when the
            methods of HookedTransformer process input text to tokenize (only when input is a string).
            Resolution order for default_prepend_bos:
                1. If user passes value explicitly, use that value
                2. Model-specific default from cfg_dict if it exists (e.g. for bloom models it's False)
                3. Global default (True)

            Even for models not explicitly trained with the BOS token, heads often use the
            first position as a resting position and accordingly lose information from the first token,
            so this empirically seems to give better results. Note that you can also locally override the default behavior
            by passing in prepend_bos=True/False when you call a method that processes the input string.
        dtype (torch.dtype, optional): The dtype to load the TransformerLens model in.
        kwargs: Other optional arguments passed to HuggingFace's from_pretrained.
            Also given to other HuggingFace functions when compatible.
    """
    if Path(model_name).exists():
        # If the model_name is a path, it's a local model
        print("If the model_name is a path, it's a local model")
        cfg_dict = convert_hf_model_config(model_name, hf_cfg, **kwargs)
        official_model_name = model_name
    else:
        print("USE get_official_model_name")
        official_model_name = get_official_model_name(model_name)
        print("official_model_name-1", official_model_name)
        if (
            official_model_name.startswith("NeelNanda")
            or official_model_name.startswith("ArthurConmy")
            or official_model_name.startswith("Baidicoot")
        ):
            cfg_dict = convert_neel_model_config(official_model_name, **kwargs)
        else:
            if official_model_name.startswith(NEED_REMOTE_CODE_MODELS) and not kwargs.get(
                "trust_remote_code", False
            ):
                logging.warning(
                    f"Loading model {official_model_name} requires setting trust_remote_code=True"
                )
                kwargs["trust_remote_code"] = True
            cfg_dict = convert_hf_model_config(official_model_name, hf_cfg, **kwargs)

    # Processing common to both model types
    # Remove any prefix, saying the organization who made a model.
    cfg_dict["model_name"] = official_model_name.split("/")[-1]
    print("official_model_name.split('/')[-1]", official_model_name.split("/")[-1])
    # Don't need to initialize weights, we're loading from pretrained
    cfg_dict["init_weights"] = False

    if (
        "positional_embedding_type" in cfg_dict
        and cfg_dict["positional_embedding_type"] == "shortformer"
        and fold_ln
    ):
        logging.warning(
            "You tried to specify fold_ln=True for a shortformer model, but this can't be done! Setting fold_ln=False instead."
        )
        fold_ln = False

    if device is not None:
        cfg_dict["device"] = device

    cfg_dict["dtype"] = dtype

    if fold_ln:
        if cfg_dict["normalization_type"] in ["LN", "LNPre"]:
            cfg_dict["normalization_type"] = "LNPre"
        elif cfg_dict["normalization_type"] in ["RMS", "RMSPre"]:
            cfg_dict["normalization_type"] = "RMSPre"
        else:
            logging.warning("Cannot fold in layer norm, normalization_type is not LN.")

    if checkpoint_index is not None or checkpoint_value is not None:
        checkpoint_labels, checkpoint_label_type = get_checkpoint_labels(
            official_model_name,
            **kwargs,
        )
        cfg_dict["from_checkpoint"] = True
        cfg_dict["checkpoint_label_type"] = checkpoint_label_type
        if checkpoint_index is not None:
            cfg_dict["checkpoint_index"] = checkpoint_index
            cfg_dict["checkpoint_value"] = checkpoint_labels[checkpoint_index]
        elif checkpoint_value is not None:
            assert (
                checkpoint_value in checkpoint_labels
            ), f"Checkpoint value {checkpoint_value} is not in list of available checkpoints"
            cfg_dict["checkpoint_value"] = checkpoint_value
            cfg_dict["checkpoint_index"] = checkpoint_labels.index(checkpoint_value)
    else:
        cfg_dict["from_checkpoint"] = False

    cfg_dict["device"] = device
    cfg_dict["n_devices"] = n_devices

    if default_prepend_bos is not None:
        # User explicitly set prepend_bos behavior, override config/default value
        cfg_dict["default_prepend_bos"] = default_prepend_bos
    elif "default_prepend_bos" not in cfg_dict:
        # No config value or user override, set default value (True)
        cfg_dict["default_prepend_bos"] = True

    if hf_cfg is not None:
        cfg_dict["load_in_4bit"] = hf_cfg.get("quantization_config", {}).get("load_in_4bit", False)
    if first_n_layers is not None:
        cfg_dict["n_layers"] = first_n_layers

    cfg = HookedTransformerConfig.from_dict(cfg_dict)
    print("cfg", cfg)
    return cfg
System Info
Additional context
Checklist
- I have checked that there is no similar issue in the repo (required)