mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-01 11:21:27 +00:00
* refactor(processor): signature of transform_features * refactor(processor): remove prefixes + processor respect new transform_features signature + update test accordingly * refactor(processor): rename now is only for visual * refactor(processor): update normalize processor * refactor(processor): update vanilla processor features * refactor(processor): feature contract now uses its own enum * chore(processor): rename renameprocessor * chore(processor): minor changes * refactor(processor): add create & change aggregate * refactor(processor): update aggregate * refactor(processor): simplify to functions, fix features contracts and rename function * test(processor): remove to converter tests as now they are very simple * chore(docs): recover docs joint observations processor * fix(processor): update RKP * fix(tests): recv diff test_pipeline * chore(tests): add docs to test * chore(processor): leave obs language constant untouched * fix(processor): correct new shape of feature in crop image processor
573 lines
22 KiB
Python
573 lines
22 KiB
Python
#!/usr/bin/env python
|
|
|
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
"""
|
|
A generic script to migrate LeRobot policies with built-in normalization layers to the new
|
|
pipeline-based processor system.
|
|
|
|
This script performs the following steps:
|
|
1. Loads a pretrained policy model and its configuration from a local path or the
|
|
Hugging Face Hub.
|
|
2. Scans the model's state dictionary to extract normalization statistics (e.g., mean,
|
|
std, min, max) for all features.
|
|
3. Creates two new processor pipelines:
|
|
- A preprocessor that normalizes inputs (observations) and outputs (actions).
|
|
- A postprocessor that unnormalizes outputs (actions) for inference.
|
|
4. Removes the original normalization layers from the model's state dictionary,
|
|
creating a "clean" model.
|
|
5. Saves the new clean model, the preprocessor, the postprocessor, and a generated
|
|
model card to a new directory.
|
|
6. Optionally pushes all the new artifacts to the Hugging Face Hub.
|
|
|
|
Usage:
|
|
python src/lerobot/processor/migrate_policy_normalization.py \
|
|
--pretrained-path lerobot/act_aloha_sim_transfer_cube_human \
|
|
--policy-type act \
|
|
--push-to-hub
|
|
"""
|
|
|
|
import argparse
|
|
import importlib
|
|
import json
|
|
import os
|
|
from copy import deepcopy
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
import torch
|
|
from huggingface_hub import hf_hub_download
|
|
from safetensors.torch import load_file as load_safetensors
|
|
|
|
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
|
|
|
from .batch_processor import AddBatchDimensionProcessorStep
|
|
from .device_processor import DeviceProcessorStep
|
|
from .normalize_processor import NormalizerProcessorStep, UnnormalizerProcessorStep
|
|
from .pipeline import PolicyProcessorPipeline
|
|
from .rename_processor import RenameObservationsProcessorStep
|
|
|
|
# Maps each supported `--policy-type` value to the dotted import path of its policy class.
# Every entry follows the same layout: `lerobot.policies.<type>.modeling_<type>.<ClassName>`.
POLICY_CLASSES = {
    policy_name: f"lerobot.policies.{policy_name}.modeling_{policy_name}.{class_name}"
    for policy_name, class_name in (
        ("act", "ACTPolicy"),
        ("diffusion", "DiffusionPolicy"),
        ("pi0", "PI0Policy"),
        ("pi0fast", "PI0FASTPolicy"),
        ("smolvla", "SmolVLAPolicy"),
        ("tdmpc", "TDMPCPolicy"),
        ("vqbet", "VQBeTPolicy"),
        ("sac", "SACPolicy"),
        ("classifier", "ClassifierPolicy"),
    )
}
|
|
|
|
|
|
def extract_normalization_stats(state_dict: dict[str, torch.Tensor]) -> dict[str, dict[str, torch.Tensor]]:
    """
    Scans a model's state_dict to find and extract normalization statistics.

    A key is treated as a normalization statistic when it starts with one of the known
    normalization-module prefixes listed below; the first matching prefix wins. The rest of the
    key is split on ``.``. The last segment is the statistic type (e.g. ``mean``, ``std``,
    ``min``, ``max``) and the preceding segments form the feature name. A leading ``buffer_``
    marker is dropped and underscores in the feature name are turned back into dots. For example,
    ``normalize_inputs.buffer_observation_state.mean`` yields ``stats["observation.state"]["mean"]``.

    Note: the underscore-to-dot reversal is lossy. A feature whose real name contains ``_``
    (e.g. ``observation.environment_state``) comes back as ``observation.environment.state``.
    Callers that know the real feature names have to map such keys back themselves.

    Args:
        state_dict: The state dictionary of a pretrained policy model.

    Returns:
        A nested dictionary ``{feature_name: {stat_type: tensor}}``. The tensors are clones, so
        the result does not alias ``state_dict``. Matching keys that have no feature segment
        (nothing before the stat type) are ignored.
    """
    stats: dict[str, dict[str, torch.Tensor]] = {}

    # Normalization-module prefixes. The `*.buffer_` forms strip the module name and the buffer
    # marker in one go; for the generic module names the `buffer_` marker is stripped below.
    # `normalizer.` mirrors the prefix that `remove_normalization_layers` drops from the model,
    # so statistics stored there are not lost during migration.
    normalization_patterns = (
        "normalize_inputs.buffer_",
        "unnormalize_outputs.buffer_",
        "normalize_targets.buffer_",
        "normalize.",
        "unnormalize.",
        "input_normalizer.",
        "output_normalizer.",
        "normalizer.",
    )

    for key, tensor in state_dict.items():
        # Only the first matching prefix is considered for a key.
        pattern = next((p for p in normalization_patterns if key.startswith(p)), None)
        if pattern is None:
            continue

        # Generic module prefixes leave the `buffer_` marker in place; drop it so it does not end
        # up in the feature name (e.g. `normalize.buffer_action.max` -> `action`, `max`).
        remaining = key[len(pattern) :].removeprefix("buffer_")
        *feature_parts, stat_type = remaining.split(".")
        if not feature_parts:
            # Need at least a feature name and a stat type.
            continue

        feature_name = ".".join(feature_parts).replace("_", ".")
        stats.setdefault(feature_name, {})[stat_type] = tensor.clone()

    return stats
|
|
|
|
|
|
def _feature_type_from_name(
    name: str,
    visual_hints: tuple[str, ...] = ("image", "visual"),
    state_hints: tuple[str, ...] = ("state",),
) -> FeatureType:
    """Guess a feature's type from (case-sensitive) substrings of its name; defaults to STATE."""
    if any(hint in name for hint in visual_hints):
        return FeatureType.VISUAL
    if any(hint in name for hint in state_hints):
        return FeatureType.STATE
    if "action" in name:
        return FeatureType.ACTION
    return FeatureType.STATE


def _parse_mapping_feature_type(name: Any) -> FeatureType:
    """Resolve a `normalization_mapping` key to a FeatureType.

    Keys that are FeatureType member names (matched case-insensitively, e.g. "VISUAL") map
    directly; anything else falls back to the name-substring heuristic.
    """
    if isinstance(name, FeatureType):
        return name
    member = FeatureType.__members__.get(str(name).upper())
    return member if member is not None else _feature_type_from_name(str(name))


def _parse_normalization_mode(mode: Any) -> NormalizationMode | None:
    """Resolve a mode value by NormalizationMode member name (case-insensitive); None if unknown."""
    if isinstance(mode, NormalizationMode):
        return mode
    return NormalizationMode.__members__.get(str(mode).upper())


def detect_features_and_norm_modes(
    config: dict[str, Any], stats: dict[str, dict[str, torch.Tensor]]
) -> tuple[dict[str, PolicyFeature], dict[FeatureType, NormalizationMode]]:
    """
    Infers policy features and normalization modes from the model config and stats.

    Normalization modes are read from ``config["normalization_mapping"]`` when present.
    - Keys may be FeatureType member names in any case (e.g. ``"VISUAL"``); other keys (e.g.
      feature names) fall back to a substring heuristic on the name.
    - Values are matched case-insensitively against NormalizationMode member names, so both
      ``"MEAN_STD"`` and ``"mean_std"`` are accepted.
    - Unknown modes are reported and skipped.

    Features come from ``config["features"]`` when present; otherwise they are inferred from the
    extracted statistics. Each feature's shape is taken from one of its stat tensors, and its
    type is guessed from its name.

    If the config gave no normalization modes, they are inferred from the statistics: a feature
    with ``mean``/``std`` implies MEAN_STD and one with ``min``/``max`` implies MIN_MAX. The first
    feature of a given type decides that type's mode. Finally, any of VISUAL / STATE / ACTION that
    is still missing gets a default (MEAN_STD / MIN_MAX / MEAN_STD).

    Args:
        config: The policy's configuration dictionary from `config.json`.
        stats: The normalization statistics extracted from the model's state_dict.

    Returns:
        A tuple containing:
        - A dictionary mapping feature names to `PolicyFeature` objects.
        - A dictionary mapping `FeatureType` enums to `NormalizationMode` enums.
    """
    features: dict[str, PolicyFeature] = {}
    norm_modes: dict[FeatureType, NormalizationMode] = {}

    # 1) Normalization modes declared in the config (later entries win for a repeated type).
    if "normalization_mapping" in config:
        print(f"Found normalization_mapping in config: {config['normalization_mapping']}")
        for feature_name, mode_str in (config["normalization_mapping"] or {}).items():
            mode = _parse_normalization_mode(mode_str)
            if mode is None:
                print(f"Warning: Unknown normalization mode '{mode_str}' for feature '{feature_name}'")
                continue
            norm_modes[_parse_mapping_feature_type(feature_name)] = mode

    # 2) Features declared in the config.
    for key, feature_config in (config.get("features") or {}).items():
        shape = feature_config.get("shape", feature_config.get("dim"))
        shape = (shape,) if isinstance(shape, int) else tuple(shape)
        features[key] = PolicyFeature(_feature_type_from_name(key), shape)

    # 3) Otherwise infer features from the statistics (shape of any stat tensor of the feature).
    if not features:
        for key, stat_dict in stats.items():
            if not stat_dict:
                continue
            shape = tuple(next(iter(stat_dict.values())).shape)
            feature_type = _feature_type_from_name(
                key,
                visual_hints=("image", "visual", "pixels"),
                state_hints=("state", "joint", "position"),
            )
            features[key] = PolicyFeature(feature_type, shape)

    # 4) Infer modes from which statistics are present, if the config declared none.
    if not norm_modes:
        for key, stat_dict in stats.items():
            if key not in features:
                continue
            if "mean" in stat_dict and "std" in stat_dict:
                inferred = NormalizationMode.MEAN_STD
            elif "min" in stat_dict and "max" in stat_dict:
                inferred = NormalizationMode.MIN_MAX
            else:
                continue
            norm_modes.setdefault(features[key].type, inferred)

    # 5) Defaults for the core feature types that are still undetermined.
    for feature_type, default_mode in (
        (FeatureType.VISUAL, NormalizationMode.MEAN_STD),
        (FeatureType.STATE, NormalizationMode.MIN_MAX),
        (FeatureType.ACTION, NormalizationMode.MEAN_STD),
    ):
        norm_modes.setdefault(feature_type, default_mode)

    return features, norm_modes
|
|
|
|
|
|
def remove_normalization_layers(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    """
    Returns a copy of ``state_dict`` without any normalization-module entries.

    A key is dropped when any of the normalization markers below occurs anywhere inside it. This
    is a plain substring test, so nested modules (e.g. ``encoder.normalize_inputs.*``) are dropped
    too. The input mapping is left untouched. Retained tensors are shared with it, not copied,
    and the original key order is preserved.

    Args:
        state_dict: The original model state dictionary.

    Returns:
        A new state dictionary containing only the core model weights.
    """
    normalization_markers = (
        "normalize_inputs.",
        "unnormalize_outputs.",
        "normalize_targets.",  # target normalization used by some policies
        "normalize.",
        "unnormalize.",
        "input_normalizer.",
        "output_normalizer.",
        "normalizer.",
    )

    return {
        name: weights
        for name, weights in state_dict.items()
        if not any(marker in name for marker in normalization_markers)
    }
|
|
|
|
|
|
def convert_features_to_policy_features(features_dict: dict[str, dict]) -> dict[str, PolicyFeature]:
    """
    Converts a feature dictionary from the old config format to the new `PolicyFeature` format.

    The feature type is taken from the entry's own ``"type"`` field when it names a FeatureType
    member (case-insensitive, e.g. ``"ENV"``). Only when no usable type is declared is it guessed
    from the key: ``image``/``visual`` -> VISUAL, ``state`` -> STATE, ``action`` -> ACTION,
    otherwise STATE. Honouring the declared type matters because the key heuristic would, for
    instance, classify ``observation.environment_state`` as STATE.

    Args:
        features_dict: The feature dictionary in the old format, where values are
            simple dictionaries (e.g. `{"type": "STATE", "shape": [7]}`). The shape may be given
            as ``"shape"`` or ``"dim"``, as a sequence or a single int.

    Returns:
        A dictionary mapping feature names to `PolicyFeature` dataclass objects.

    Raises:
        ValueError: If an entry has neither a ``"shape"`` nor a ``"dim"`` field.
    """
    converted_features: dict[str, PolicyFeature] = {}

    for key, feature_dict in features_dict.items():
        declared_type = feature_dict.get("type")
        if isinstance(declared_type, FeatureType):
            feature_type = declared_type
        elif isinstance(declared_type, str) and declared_type.upper() in FeatureType.__members__:
            feature_type = FeatureType[declared_type.upper()]
        # No usable declared type: fall back to guessing from the key.
        elif "image" in key or "visual" in key:
            feature_type = FeatureType.VISUAL
        elif "state" in key:
            feature_type = FeatureType.STATE
        elif "action" in key:
            feature_type = FeatureType.ACTION
        else:
            feature_type = FeatureType.STATE

        shape = feature_dict.get("shape", feature_dict.get("dim"))
        if shape is None:
            raise ValueError(f"Feature '{key}' has neither a 'shape' nor a 'dim' entry")
        shape = (shape,) if isinstance(shape, int) else tuple(shape)

        converted_features[key] = PolicyFeature(feature_type, shape)

    return converted_features
|
|
|
|
|
|
def load_model_from_hub(
    repo_id: str, revision: str | None = None
) -> tuple[dict[str, torch.Tensor], dict[str, Any], dict[str, Any]]:
    """
    Downloads and loads a model's state_dict and configs from the Hugging Face Hub.

    Fetches ``model.safetensors``, ``config.json`` and ``train_config.json`` from the repository.
    All three must exist there.

    Args:
        repo_id: The repository ID on the Hub (e.g., 'lerobot/aloha').
        revision: The specific git revision (branch, tag, or commit hash) to use; ``None`` uses
            the repository's default revision.

    Returns:
        A tuple containing the model's state dictionary, the policy configuration,
        and the training configuration.
    """
    safetensors_path = hf_hub_download(repo_id=repo_id, filename="model.safetensors", revision=revision)
    config_path = hf_hub_download(repo_id=repo_id, filename="config.json", revision=revision)
    train_config_path = hf_hub_download(repo_id=repo_id, filename="train_config.json", revision=revision)

    state_dict = load_safetensors(safetensors_path)

    # JSON is UTF-8 by specification; don't depend on the platform's locale default encoding.
    with open(config_path, encoding="utf-8") as f:
        config = json.load(f)

    with open(train_config_path, encoding="utf-8") as f:
        train_config = json.load(f)

    return state_dict, config, train_config
|
|
|
|
|
|
def main():
|
|
parser = argparse.ArgumentParser(
|
|
description="Migrate policy models with normalization layers to new pipeline system"
|
|
)
|
|
parser.add_argument(
|
|
"--pretrained-path",
|
|
type=str,
|
|
required=True,
|
|
help="Path to pretrained model (hub repo or local directory)",
|
|
)
|
|
parser.add_argument(
|
|
"--output-dir",
|
|
type=str,
|
|
default=None,
|
|
help="Output directory for migrated model (default: same as pretrained-path)",
|
|
)
|
|
parser.add_argument("--push-to-hub", action="store_true", help="Push migrated model to hub")
|
|
parser.add_argument(
|
|
"--hub-repo-id",
|
|
type=str,
|
|
default=None,
|
|
help="Hub repository ID for pushing (default: same as pretrained-path)",
|
|
)
|
|
parser.add_argument("--revision", type=str, default=None, help="Revision of the model to load")
|
|
parser.add_argument("--private", action="store_true", help="Make the hub repository private")
|
|
|
|
args = parser.parse_args()
|
|
|
|
# Load model and config
|
|
print(f"Loading model from {args.pretrained_path}...")
|
|
if os.path.isdir(args.pretrained_path):
|
|
# Local directory
|
|
state_dict = load_safetensors(os.path.join(args.pretrained_path, "model.safetensors"))
|
|
with open(os.path.join(args.pretrained_path, "config.json")) as f:
|
|
config = json.load(f)
|
|
with open(os.path.join(args.pretrained_path, "train_config.json")) as f:
|
|
train_config = json.load(f)
|
|
else:
|
|
# Hub repository
|
|
state_dict, config, train_config = load_model_from_hub(args.pretrained_path, args.revision)
|
|
|
|
# Extract normalization statistics
|
|
print("Extracting normalization statistics...")
|
|
stats = extract_normalization_stats(state_dict)
|
|
|
|
print(f"Found normalization statistics for: {list(stats.keys())}")
|
|
|
|
# Detect input features and normalization modes
|
|
print("Detecting features and normalization modes...")
|
|
features, norm_map = detect_features_and_norm_modes(config, stats)
|
|
|
|
print(f"Detected features: {list(features.keys())}")
|
|
print(f"Normalization modes: {norm_map}")
|
|
|
|
# Remove normalization layers from state_dict
|
|
print("Removing normalization layers from model...")
|
|
new_state_dict = remove_normalization_layers(state_dict)
|
|
|
|
removed_keys = set(state_dict.keys()) - set(new_state_dict.keys())
|
|
if removed_keys:
|
|
print(f"Removed {len(removed_keys)} normalization layer keys")
|
|
|
|
# Determine output path
|
|
if args.output_dir:
|
|
output_dir = Path(args.output_dir)
|
|
else:
|
|
if os.path.isdir(args.pretrained_path):
|
|
output_dir = Path(args.pretrained_path).parent / f"{Path(args.pretrained_path).name}_migrated"
|
|
else:
|
|
output_dir = Path(f"./{args.pretrained_path.replace('/', '_')}_migrated")
|
|
|
|
output_dir.mkdir(parents=True, exist_ok=True)
|
|
|
|
# Clean up config - remove normalization_mapping field
|
|
cleaned_config = dict(config)
|
|
if "normalization_mapping" in cleaned_config:
|
|
print("Removing 'normalization_mapping' field from config")
|
|
del cleaned_config["normalization_mapping"]
|
|
policy_type = deepcopy(cleaned_config["type"])
|
|
|
|
del cleaned_config["type"]
|
|
|
|
# Instantiate the policy model with cleaned config and load the cleaned state dict
|
|
print(f"Instantiating {policy_type} policy model...")
|
|
policy_class_path = POLICY_CLASSES[policy_type]
|
|
module_path, class_name = policy_class_path.rsplit(".", 1)
|
|
|
|
module = importlib.import_module(module_path)
|
|
policy_class = getattr(module, class_name)
|
|
|
|
# Create config class instance
|
|
config_module_path = module_path.replace("modeling", "configuration")
|
|
config_module = importlib.import_module(config_module_path)
|
|
# Handle special cases for config class names
|
|
config_class_names = {
|
|
"act": "ACTConfig",
|
|
"diffusion": "DiffusionConfig",
|
|
"pi0": "PI0Config",
|
|
"pi0fast": "PI0FASTConfig",
|
|
"smolvla": "SmolVLAConfig",
|
|
"tdmpc": "TDMPCConfig",
|
|
"vqbet": "VQBeTConfig",
|
|
"sac": "SACConfig",
|
|
"classifier": "ClassifierConfig",
|
|
}
|
|
config_class_name = config_class_names.get(policy_type, f"{policy_type.upper()}Config")
|
|
config_class = getattr(config_module, config_class_name)
|
|
|
|
# Convert input_features and output_features to PolicyFeature objects - these are mandatory
|
|
if "input_features" not in cleaned_config:
|
|
raise ValueError("Missing mandatory 'input_features' in config")
|
|
if "output_features" not in cleaned_config:
|
|
raise ValueError("Missing mandatory 'output_features' in config")
|
|
|
|
cleaned_config["input_features"] = convert_features_to_policy_features(cleaned_config["input_features"])
|
|
cleaned_config["output_features"] = convert_features_to_policy_features(cleaned_config["output_features"])
|
|
|
|
# Create config instance from cleaned config dict
|
|
policy_config = config_class(**cleaned_config)
|
|
|
|
# Create policy instance - some policies expect dataset_stats
|
|
policy = policy_class(policy_config)
|
|
|
|
# Load the cleaned state dict
|
|
policy.load_state_dict(new_state_dict, strict=True)
|
|
print("Successfully loaded cleaned state dict into policy model")
|
|
|
|
# Now create preprocessor and postprocessor with cleaned_config available
|
|
print("Creating preprocessor and postprocessor...")
|
|
# The pattern from existing processor factories:
|
|
# - Preprocessor has two NormalizerProcessorSteps: one for input_features, one for output_features
|
|
# - Postprocessor has one UnnormalizerProcessorStep for output_features only
|
|
|
|
# Get features from cleaned_config (now they're PolicyFeature objects)
|
|
input_features = cleaned_config.get("input_features", {})
|
|
output_features = cleaned_config.get("output_features", {})
|
|
|
|
# Create preprocessor with two normalizers (following the pattern from processor factories)
|
|
preprocessor_steps = [
|
|
RenameObservationsProcessorStep(rename_map={}),
|
|
NormalizerProcessorStep(
|
|
features={**input_features, **output_features},
|
|
norm_map=norm_map,
|
|
stats=stats,
|
|
),
|
|
AddBatchDimensionProcessorStep(),
|
|
DeviceProcessorStep(device=policy_config.device),
|
|
]
|
|
preprocessor = PolicyProcessorPipeline(steps=preprocessor_steps, name="robot_preprocessor")
|
|
|
|
# Create postprocessor with unnormalizer for outputs only
|
|
postprocessor_steps = [
|
|
DeviceProcessorStep(device="cpu"),
|
|
UnnormalizerProcessorStep(features=output_features, norm_map=norm_map, stats=stats),
|
|
]
|
|
postprocessor = PolicyProcessorPipeline(steps=postprocessor_steps, name="robot_postprocessor")
|
|
|
|
# Determine hub repo ID if pushing to hub
|
|
if args.push_to_hub:
|
|
if args.hub_repo_id:
|
|
hub_repo_id = args.hub_repo_id
|
|
else:
|
|
if not os.path.isdir(args.pretrained_path):
|
|
# Use same repo with "_migrated" suffix
|
|
hub_repo_id = f"{args.pretrained_path}_migrated"
|
|
else:
|
|
raise ValueError("--hub-repo-id must be specified when pushing local model to hub")
|
|
else:
|
|
hub_repo_id = None
|
|
|
|
# Save preprocessor and postprocessor to root directory
|
|
print(f"Saving preprocessor to {output_dir}...")
|
|
preprocessor.save_pretrained(output_dir)
|
|
if args.push_to_hub:
|
|
preprocessor.push_to_hub(repo_id=hub_repo_id, private=args.private)
|
|
|
|
print(f"Saving postprocessor to {output_dir}...")
|
|
postprocessor.save_pretrained(output_dir)
|
|
if args.push_to_hub:
|
|
postprocessor.push_to_hub(repo_id=hub_repo_id, private=args.private)
|
|
|
|
# Save model using the policy's save_pretrained method
|
|
print(f"Saving model to {output_dir}...")
|
|
policy.save_pretrained(
|
|
output_dir, push_to_hub=args.push_to_hub, repo_id=hub_repo_id, private=args.private
|
|
)
|
|
|
|
# Generate and save model card
|
|
print("Generating model card...")
|
|
# Get metadata from original config
|
|
dataset_repo_id = train_config.get("repo_id", "unknown")
|
|
license = config.get("license", "apache-2.0")
|
|
|
|
tags = config.get("tags", ["robotics", "lerobot", policy_type]) or ["robotics", "lerobot", policy_type]
|
|
tags = set(tags).union({"robotics", "lerobot", policy_type})
|
|
tags = list(tags)
|
|
|
|
# Generate model card
|
|
card = policy.generate_model_card(
|
|
dataset_repo_id=dataset_repo_id, model_type=policy_type, license=license, tags=tags
|
|
)
|
|
|
|
# Save model card locally
|
|
card.save(str(output_dir / "README.md"))
|
|
print(f"Model card saved to {output_dir / 'README.md'}")
|
|
# Push model card to hub if requested
|
|
if args.push_to_hub:
|
|
from huggingface_hub import HfApi
|
|
|
|
api = HfApi()
|
|
api.upload_file(
|
|
path_or_fileobj=str(output_dir / "README.md"),
|
|
path_in_repo="README.md",
|
|
repo_id=hub_repo_id,
|
|
repo_type="model",
|
|
commit_message="Add model card for migrated model",
|
|
)
|
|
print("Model card pushed to hub")
|
|
|
|
print("\nMigration complete!")
|
|
print(f"Migrated model saved to: {output_dir}")
|
|
if args.push_to_hub:
|
|
print(f"Successfully pushed to https://huggingface.co/{hub_repo_id}")
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|