-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add pruning configuration and update requirements; enhance logging and example files
- Loading branch information
Showing
18 changed files
with
699 additions
and
131 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,4 +4,6 @@ venv | |
/**/.vscode/ | ||
/**/.idea/ | ||
/**/.mypy_cache/ | ||
prune* | ||
prune*/ | ||
logs/* | ||
configs/* |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
# Example pruning configuration using the default "mk_prune" method.
# NOTE(review): trailing "|" characters are diff-table artifacts from the
# page scrape, not part of the YAML file itself.
model_name: "meditsolutions/Llama-3.2-SUN-1B-chat" | ||
prune_percent: 0.2 | ||
dtype: "float32" | ||
cache_dir: null | ||
device: "cuda" | ||
output: "results/pruned_model" | ||
apply_chat_template: true | ||
prompt: "What is the capital of France?" | ||
max_new_tokens: 50 | ||
target_size: null | ||
log_dir: "logs" | ||
test_only: false | ||
prune_method: "mk_prune" | ||
use_normalized_weights: false | ||
use_layer_norm_tweaks: false | ||
layer_norm_scale: 4.0 | ||
gate_up_weight_weights: [1.0, 1.0] | ||
print_summary: false | ||
quiet: false | ||
stop_logging: false |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
# Example pruning configuration including a grid-search (AutoML) section;
# uses the "mk" alias for the prune method.
# NOTE(review): trailing "|" characters are diff-table artifacts from the
# page scrape, not part of the YAML file itself.
model_name: "meditsolutions/Llama-3.2-SUN-1B-chat" | ||
prune_percent: 0.2 | ||
dtype: "float32" | ||
cache_dir: null | ||
device: "cuda" | ||
output: "results/pruned_model" | ||
apply_chat_template: true | ||
prompt: "What is the capital of France?" | ||
max_new_tokens: 50 | ||
target_size: null | ||
log_dir: "logs" | ||
test_only: false | ||
prune_method: "mk" | ||
use_normalized_weights: false | ||
use_layer_norm_tweaks: false | ||
layer_norm_scale: 4.0 | ||
gate_up_weight_weights: [1.0, 1.0] | ||
print_summary: false | ||
quiet: false | ||
stop_logging: false | ||
|
||
# Prune grid search parameters | ||
grid_search: | ||
prune_percent_list: [0.1, 0.2, 0.3] | ||
prune_method_list: ["mk_prune", "mk_prune_adjusted"] | ||
use_normalized_weights_list: [true, false] | ||
use_layer_norm_tweaks_list: [true, false] | ||
layer_norm_scale_list: [2.0, 4.0] | ||
target_size_list: [null, 512] | ||
gate_up_weight_weights_list: [[1.0, 1.0], [0.4, 0.6]] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
# Example pruning configuration using the "mka" (mk_prune_adjusted) alias
# with layer-norm tweaks enabled and asymmetric gate/up weights.
# NOTE(review): trailing "|" characters are diff-table artifacts from the
# page scrape, not part of the YAML file itself.
model_name: "meditsolutions/Llama-3.2-SUN-1B-chat" | ||
prune_percent: 0.2 | ||
dtype: "float32" | ||
cache_dir: null | ||
device: "cuda" | ||
output: "results/pruned_model" | ||
apply_chat_template: true | ||
prompt: "What is the capital of France?" | ||
max_new_tokens: 50 | ||
target_size: null | ||
log_dir: "logs" | ||
test_only: false | ||
prune_method: "mka" | ||
use_normalized_weights: false | ||
use_layer_norm_tweaks: true | ||
layer_norm_scale: 2.0 | ||
gate_up_weight_weights: [0.3, 0.7] | ||
print_summary: false | ||
quiet: false | ||
stop_logging: false |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
# Example test-only configuration (test_only: true) that also suppresses
# console output and logging (quiet/stop_logging: true).
# NOTE(review): trailing "|" characters are diff-table artifacts from the
# page scrape, not part of the YAML file itself.
model_name: "meditsolutions/Llama-3.2-SUN-1B-chat" | ||
prune_percent: 0.2 | ||
dtype: "float32" | ||
cache_dir: null | ||
device: "cuda" | ||
output: "results/pruned_model" | ||
apply_chat_template: true | ||
prompt: "What is the capital of France?" | ||
max_new_tokens: 50 | ||
target_size: null | ||
log_dir: "logs" | ||
test_only: true | ||
prune_method: "mka" | ||
use_normalized_weights: false | ||
use_layer_norm_tweaks: true | ||
layer_norm_scale: 5.0 | ||
gate_up_weight_weights: [0.3, 0.7] | ||
print_summary: false | ||
quiet: true | ||
stop_logging: true |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
# Example test-only configuration (test_only: true) with normal console
# output and logging still enabled.
# NOTE(review): trailing "|" characters are diff-table artifacts from the
# page scrape, not part of the YAML file itself.
model_name: "meditsolutions/Llama-3.2-SUN-1B-chat" | ||
prune_percent: 0.2 | ||
dtype: "float32" | ||
cache_dir: null | ||
device: "cuda" | ||
output: "results/pruned_model" | ||
apply_chat_template: true | ||
prompt: "What is the capital of France?" | ||
max_new_tokens: 50 | ||
target_size: null | ||
log_dir: "logs" | ||
test_only: true | ||
prune_method: "mka" | ||
use_normalized_weights: false | ||
use_layer_norm_tweaks: true | ||
layer_norm_scale: 5.0 | ||
gate_up_weight_weights: [0.3, 0.7] | ||
print_summary: false | ||
quiet: false | ||
stop_logging: false |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
from dataclasses import dataclass, field | ||
from typing import Optional, List | ||
from config.prune_grid_search_config import PruneGridSearchConfig | ||
|
||
import logging | ||
import os | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
# Dataclass to store the prune configuration.
@dataclass
class PruneConfig:
    """Validated configuration for a model-pruning run.

    Defaults mirror the example YAML configs. ``__post_init__`` normalizes
    the prune-method alias, validates numeric ranges, and creates the
    output directory.

    Raises:
        ValueError: on an unknown prune method or any out-of-range value.
    """

    model_name: str = "meditsolutions/Llama-3.2-SUN-1B-chat"
    dtype: str = "float32"
    device: str = "cuda"
    output: str = "results/pruned_model"
    # NOTE: the original declared apply_chat_template twice; the duplicate
    # (which carried the same default) has been removed.
    apply_chat_template: bool = False
    prompt: str = "What is the capital of France?"
    max_new_tokens: int = 50
    prune_percent: float = 0.2
    prune_method: str = "mk_prune"
    use_normalized_weights: bool = False
    use_layer_norm_tweaks: bool = False
    layer_norm_scale: float = 4.0
    log_dir: str = "logs"
    stop_logging: bool = False
    test_only: bool = False
    print_summary: bool = False
    quiet: bool = False
    cache_dir: Optional[str] = None
    target_size: Optional[int] = None
    gate_up_weight_weights: Optional[List[float]] = field(
        default_factory=lambda: [1.0, 1.0]
    )
    eval_dataset: Optional[str] = None
    eval_dataset_size: Optional[int] = None
    # Forward reference: quoted so the class body does not require
    # PruneGridSearchConfig to be importable at definition time.
    grid_search: Optional["PruneGridSearchConfig"] = None

    def __post_init__(self):
        # Normalize the prune-method aliases ("mk" -> "mk_prune",
        # "mka" -> "mk_prune_adjusted"); anything else is rejected.
        if self.prune_method in ("mk", "mk_prune"):
            self.prune_method = "mk_prune"
        elif self.prune_method in ("mka", "mk_prune_adjusted"):
            self.prune_method = "mk_prune_adjusted"
        else:
            raise ValueError(f"Unknown prune method: {self.prune_method}")

        # Validate parameters.
        if self.use_layer_norm_tweaks and self.layer_norm_scale <= 0:
            raise ValueError("layer_norm_scale must be greater than 0.")

        if self.gate_up_weight_weights is not None:
            if len(self.gate_up_weight_weights) != 2:
                raise ValueError("gate_up_weight_weights must have 2 values.")
            if any(w < 0 or w > 1.0 for w in self.gate_up_weight_weights):
                raise ValueError("gate_up_weight_weights must be between 0 and 1")

        if self.target_size is not None and self.target_size <= 0:
            raise ValueError("target_size must be greater than 0.")

        if self.prune_percent < 0 or self.prune_percent > 1.0:
            raise ValueError("prune_percent must be between 0 and 1")

        if self.max_new_tokens <= 10:
            raise ValueError("max_new_tokens must be greater than 10")

        if self.cache_dir is not None and not os.path.exists(self.cache_dir):
            raise ValueError("cache_dir does not exist.")

        # Side effect: make sure the output directory exists.
        if not os.path.exists(self.output):
            os.makedirs(self.output)

        # BUG FIX: the original raised when eval_dataset was set without
        # grid_search — the inverse of its own error message. Grid search
        # is what requires an eval dataset, so the check is reversed here.
        if self.grid_search is not None and self.eval_dataset is None:
            raise ValueError("Grid search (AutoML) requires an eval_dataset.")

        if self.eval_dataset is not None and self.eval_dataset_size is None:
            logger.warning("eval_dataset_size is not set. Defaulting to 20.")
            self.eval_dataset_size = 20
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
from dataclasses import dataclass, field | ||
from typing import Optional, List | ||
|
||
import itertools | ||
|
||
|
||
# Dataclass to store the prune grid search configuration.
@dataclass
class PruneGridSearchConfig:
    """Parameter grids for a pruning grid search (AutoML).

    Each ``*_list`` field holds candidate values for one ``PruneConfig``
    parameter; ``generate_param_combinations`` yields their Cartesian
    product. All grids are validated at construction time.

    Raises:
        AssertionError: if any grid contains an out-of-range value.
    """

    prune_percent_list: List[float] = field(default_factory=lambda: [0.2])
    prune_method_list: List[str] = field(default_factory=lambda: ["mk_prune"])
    use_normalized_weights_list: List[bool] = field(default_factory=lambda: [False])
    use_layer_norm_tweaks_list: List[bool] = field(default_factory=lambda: [False])
    layer_norm_scale_list: List[float] = field(default_factory=lambda: [4.0])
    target_size_list: List[Optional[int]] = field(default_factory=lambda: [None])
    gate_up_weight_weights_list: List[List[float]] = field(
        default_factory=lambda: [[1.0, 1.0]]
    )

    def __post_init__(self):
        # Validate every grid axis up front so a bad grid fails at
        # construction time rather than mid-search. (The original also
        # called generate_param_combinations() here and discarded the
        # result; that call crashed — see the method below — and served
        # no purpose, so it is removed.)
        assert all(
            0.0 <= p <= 1.0 for p in self.prune_percent_list
        ), "Prune percent must be between 0 and 1"
        assert all(
            m in ["mk_prune", "mk_prune_adjusted", "mk", "mka"]
            for m in self.prune_method_list
        ), "Prune method must be 'mk_prune' or 'mk_prune_adjusted' or 'mk' or 'mka'"
        assert all(
            isinstance(u, bool) for u in self.use_normalized_weights_list
        ), "Use normalized weights must be a boolean"
        assert all(
            isinstance(u, bool) for u in self.use_layer_norm_tweaks_list
        ), "Use layer norm tweaks must be a boolean"
        assert all(
            0.0 <= s for s in self.layer_norm_scale_list
        ), "Layer norm scale must be non-negative"
        assert all(
            t is None or t > 0 for t in self.target_size_list
        ), "Target size must be None or a positive integer"
        assert all(
            all(0.0 <= w <= 1.0 for w in weights)
            for weights in self.gate_up_weight_weights_list
        ), "Gate up weight weights must be between 0 and 1"

    # BUG FIX: this was decorated @classmethod while taking ``self`` and
    # reading per-instance grid fields. Dataclass fields declared with
    # default_factory have no class attribute, so every call (including
    # the one from __post_init__) raised AttributeError — the class could
    # never be instantiated. It is an instance method.
    def generate_param_combinations(self):
        """Return an iterator over every grid combination, ordered as
        (prune_percent, prune_method, use_normalized_weights,
        use_layer_norm_tweaks, layer_norm_scale, target_size,
        gate_up_weight_weights)."""
        return itertools.product(
            self.prune_percent_list,
            self.prune_method_list,
            self.use_normalized_weights_list,
            self.use_layer_norm_tweaks_list,
            self.layer_norm_scale_list,
            self.target_size_list,
            self.gate_up_weight_weights_list,
        )
Oops, something went wrong.