[mypy] nncf/common #3212

Merged
6 changes: 3 additions & 3 deletions nncf/api/compression.py
@@ -15,8 +15,8 @@
from enum import IntEnum
from typing import Any, Dict, List, Optional, Tuple, TypeVar

from nncf.api.statistics import Statistics
from nncf.common.graph.transformations.layout import TransformationLayout
from nncf.common.statistics import NNCFStatistics
from nncf.common.utils.api_marker import api
from nncf.common.utils.backend import copy_model

@@ -227,9 +227,9 @@ def compression_stage(self) -> CompressionStage:
"""

@abstractmethod
def statistics(self, quickly_collected_only: bool = False) -> Statistics:
def statistics(self, quickly_collected_only: bool = False) -> NNCFStatistics:
"""
Returns a `Statistics` class instance that contains compression algorithm statistics.
Returns a `NNCFStatistics` class instance that contains compression algorithm statistics.

:param quickly_collected_only: Enables collection of the statistics that
don't take too much time to compute. Can be helpful for the case when
3 changes: 1 addition & 2 deletions nncf/common/accuracy_aware_training/runner.py
@@ -21,7 +21,6 @@
from nncf.api.compression import CompressionStage
from nncf.common.logging import nncf_logger
from nncf.common.plotting import noninteractive_plotting
from nncf.common.statistics import NNCFStatistics
from nncf.common.utils.helpers import configure_accuracy_aware_paths
from nncf.common.utils.tensorboard import prepare_for_tensorboard
from nncf.config.schemata.defaults import AA_COMPRESSION_RATE_STEP_REDUCTION_FACTOR
@@ -294,7 +293,7 @@ def train_epoch(self, model: TModel, compression_controller: CompressionAlgorith
self.cumulative_epoch_count += 1

def dump_statistics(self, model: TModel, compression_controller: CompressionAlgorithmController) -> None:
statistics = cast(NNCFStatistics, compression_controller.statistics())
statistics = compression_controller.statistics()

if self.verbose:
nncf_logger.info(statistics.to_str())
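The two files above apply one pattern: the abstract `statistics()` now declares the concrete `NNCFStatistics` return type, so call sites such as `dump_statistics` can drop the `cast(...)` and still type-check. A minimal, self-contained sketch of that pattern with toy classes (the names below are illustrative, not NNCF API):

```python
from abc import ABC, abstractmethod


class ToyStatistics:
    """Stands in for NNCFStatistics: a concrete container with a to_str() helper."""

    def to_str(self) -> str:
        return "ratio: 0.5"


class ToyController(ABC):
    # Declaring the concrete type here (instead of a looser base class) is what
    # lets callers use ToyStatistics-only methods without a cast().
    @abstractmethod
    def statistics(self) -> ToyStatistics: ...


class ToyMagnitudeController(ToyController):
    def statistics(self) -> ToyStatistics:
        return ToyStatistics()


def dump(ctrl: ToyController) -> None:
    stats = ctrl.statistics()  # inferred as ToyStatistics, no cast() needed
    print(stats.to_str())


dump(ToyMagnitudeController())
```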
4 changes: 2 additions & 2 deletions nncf/common/accuracy_aware_training/runner_factory.py
@@ -59,7 +59,7 @@ def create_training_loop(self) -> BaseAccuracyAwareTrainingRunner:

:return: AccuracyAwareTrainingRunner object
"""
nncf_backend = get_backend(self.compression_controller.model) # type: ignore
nncf_backend = get_backend(self.compression_controller.model)
if nncf_backend is BackendType.TORCH:
from nncf.torch.accuracy_aware_training.runner import PTAccuracyAwareTrainingRunner

@@ -114,7 +114,7 @@ def create_training_loop(self) -> BaseAdaptiveCompressionLevelTrainingRunner:

:return: AdaptiveCompressionLevelTrainingRunner object
"""
nncf_backend = get_backend(self.compression_controller.model) # type: ignore
nncf_backend = get_backend(self.compression_controller.model)

if nncf_backend is BackendType.TORCH:
from nncf.torch.accuracy_aware_training.runner import PTAdaptiveCompressionLevelTrainingRunner
55 changes: 27 additions & 28 deletions nncf/common/composite_compression.py
@@ -32,9 +32,9 @@ class CompositeCompressionLoss(CompressionLoss):
`CompressionLoss` instance.
"""

def __init__(self):
def __init__(self) -> None:
super().__init__()
self._child_losses = []
self._child_losses: List[CompressionLoss] = []

@property
def child_losses(self) -> List[CompressionLoss]:
@@ -48,7 +48,7 @@ def add(self, child_loss: CompressionLoss) -> None:
"""
self._child_losses.append(child_loss)

def load_state(self, state: List[Dict[str, Any]]) -> None:
def load_state(self, state: List[Dict[str, Any]]) -> None: # type: ignore[override]
"""
Loads the composite compression loss state.

@@ -57,7 +57,7 @@ def load_state(self, state: List[Dict[str, Any]]) -> None:
for child_loss, child_state in zip(self._child_losses, state):
child_loss.load_state(child_state)

def get_state(self) -> List[Dict[str, Any]]:
def get_state(self) -> List[Dict[str, Any]]: # type: ignore[override]
"""
Returns the composite compression loss state.

@@ -68,7 +68,7 @@ def get_state(self) -> List[Dict[str, Any]]:
composite_state.append(child_loss.get_state())
return composite_state

def calculate(self, *args, **kwargs) -> Any:
def calculate(self, *args: Any, **kwargs: Any) -> Any:
"""
Traverses through all children and calculates the total compression
loss value.
@@ -92,9 +92,9 @@ class CompositeCompressionScheduler(CompressionScheduler):
`CompressionScheduler` instance.
"""

def __init__(self):
def __init__(self) -> None:
super().__init__()
self._child_schedulers = []
self._child_schedulers: List[CompressionScheduler] = []

@property
def child_schedulers(self) -> List[CompressionScheduler]:
@@ -130,7 +130,7 @@ def epoch_step(self, next_epoch: Optional[int] = None) -> None:
for scheduler in self._child_schedulers:
scheduler.epoch_step(next_epoch)

def load_state(self, state: List[Dict[str, Any]]) -> None:
def load_state(self, state: List[Dict[str, Any]]) -> None: # type: ignore[override]
"""
Calls `load_state()` method for all children.

@@ -139,7 +139,7 @@ def load_state(self, state: List[Dict[str, Any]]) -> None:
for child_scheduler, child_state in zip(self._child_schedulers, state):
child_scheduler.load_state(child_state)

def get_state(self) -> List[Dict[str, Any]]:
def get_state(self) -> List[Dict[str, Any]]: # type: ignore[override]
"""
Returns the composite compression scheduler state. This state contains
the state of all children.
@@ -172,11 +172,11 @@ def __init__(self, target_model: TModel):
by the `CompressionAlgorithmBuilder`.
"""
super().__init__(target_model)
self._child_ctrls = []
self._child_ctrls: List[CompressionAlgorithmController] = []
self._loss = CompositeCompressionLoss()
self._scheduler = CompositeCompressionScheduler()
self._builder_state = None
self._name = None
self._builder_state: Optional[Dict[str, Any]] = None
self._name: Optional[str] = None

@property
def loss(self) -> CompressionLoss:
@@ -192,7 +192,9 @@ def child_ctrls(self) -> List[CompressionAlgorithmController]:

@property
def name(self) -> str:
raise self._name
if self._name is None:
raise nncf.InternalError("Internal error: algorithm name is not set for the controller")
return self._name

def add(self, child_ctrl: CompressionAlgorithmController) -> None:
"""
@@ -219,13 +221,10 @@ def compression_stage(self) -> CompressionStage:
"""
if not self.child_ctrls:
return CompressionStage.UNCOMPRESSED
result = None
for ctrl in self.child_ctrls:
result = self.child_ctrls[0].compression_stage()
for ctrl in self.child_ctrls[1:]:
current_level = ctrl.compression_stage()
if result is None:
result = current_level
else:
result += current_level
result += current_level
return result

def load_state(self, state: Dict[str, Dict[str, Any]]) -> None:
@@ -277,13 +276,13 @@ def prepare_for_export(self) -> None:
stripped_model = ctrl.strip_model(stripped_model)
self._model = stripped_model

def strip(self, do_copy: bool = True) -> TModel:
def strip(self, do_copy: bool = True) -> TModel: # type: ignore
model = self.model
if do_copy:
model = copy_model(model)
for ctrl in self.child_ctrls:
model = ctrl.strip_model(model, do_copy=False)
return model
return model # type: ignore

@property
def compression_rate(self) -> float:
@@ -329,12 +328,12 @@ def export_model(
if backend is BackendType.TENSORFLOW:
from nncf.tensorflow.exporter import TFExporter

exporter = TFExporter(self.model, input_names, output_names, model_args)
exporter = TFExporter(self.model, input_names, output_names, model_args) # type: ignore
else:
assert backend is BackendType.TORCH
from nncf.torch.exporter import PTExporter

exporter = PTExporter(self.model, input_names, output_names, model_args)
exporter = PTExporter(self.model, input_names, output_names, model_args) # type: ignore
if save_format is not None:
exporter.export_model(save_path, save_format)
else:
@@ -352,7 +351,7 @@ def get_compression_state(self) -> Dict[str, Any]:

return {self.BUILDER_STATE: self._builder_state, self.CONTROLLER_STATE: self.get_state()}

def set_builder_state_with_name(self, name: str, builder_state: Dict):
def set_builder_state_with_name(self, name: str, builder_state: Dict[str, Any]) -> None:
"""
Sets state of the builder and the corresponding algorithm name. Should be called by the builder to set its
state and registered algorithm key.
@@ -382,16 +381,16 @@ def __init__(self, config: NNCFConfig, should_init: bool = True):
"""
self._config = config
self.should_init = should_init
self._child_builders = []
self._child_builders: List[CompressionAlgorithmBuilder] = []

def _get_algo_specific_config_section(self) -> Dict:
def _get_algo_specific_config_section(self) -> Dict[str, Any]:
return {}

@property
def child_builders(self) -> List[CompressionAlgorithmBuilder]:
return self._child_builders

def load_state(self, state: Dict[str, Dict]) -> None:
def load_state(self, state: Dict[str, Dict[str, Any]]) -> None:
"""
Loads the compression builder state of children

Expand All @@ -400,7 +399,7 @@ def load_state(self, state: Dict[str, Dict]) -> None:
for builder in self.child_builders:
builder.load_state(state)

def get_state(self) -> Dict[str, Dict]:
def get_state(self) -> Dict[str, Dict[str, Any]]:
"""
Returns the composite compression builder state. This state contains
the state of all children.
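Besides the annotations, this file adds targeted `# type: ignore[override]` suppressions where a subclass intentionally changes the state type, and rewrites `compression_stage()` to seed the fold from the first child controller instead of accumulating through an `Optional`. A toy sketch of that aggregation shape (stand-in enum, illustrative combining rule, not NNCF's actual `CompressionStage`):

```python
from enum import IntEnum
from typing import List


class ToyStage(IntEnum):
    UNCOMPRESSED = 0
    PARTIALLY_COMPRESSED = 1
    FULLY_COMPRESSED = 2

    def __add__(self, other: "ToyStage") -> "ToyStage":  # type: ignore[override]
        # Illustrative rule: combining different stages yields a partial stage.
        return self if self == other else ToyStage.PARTIALLY_COMPRESSED


def combined_stage(stages: List[ToyStage]) -> ToyStage:
    if not stages:
        return ToyStage.UNCOMPRESSED
    # Seeding with the first element keeps the accumulator a plain ToyStage,
    # so no Optional leaks into the return type.
    result = stages[0]
    for stage in stages[1:]:
        result += stage
    return result


print(combined_stage([ToyStage.FULLY_COMPRESSED, ToyStage.PARTIALLY_COMPRESSED]))
```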
22 changes: 11 additions & 11 deletions nncf/common/compression.py
@@ -61,18 +61,18 @@ def __init__(self, target_model: TModel):
by the `CompressionAlgorithmBuilder`.
"""
super().__init__(target_model)
self._name = None
self._builder_state = None
self._name: Optional[str] = None
self._builder_state: Optional[Dict[str, Any]] = None

@property
def name(self):
def name(self) -> str:
if self._name is None:
raise nncf.InternalError("Internal error: name of the controller is not set!")
return self._name

@property
def compression_rate(self) -> float:
return None
return None # type: ignore

@compression_rate.setter
def compression_rate(self) -> float:
@@ -111,12 +111,12 @@ def export_model(
if backend is BackendType.TENSORFLOW:
from nncf.tensorflow.exporter import TFExporter

exporter = TFExporter(self.model, input_names, output_names, model_args)
exporter = TFExporter(self.model, input_names, output_names, model_args) # type: ignore
else:
assert backend is BackendType.TORCH
from nncf.torch.exporter import PTExporter

exporter = PTExporter(self.model, input_names, output_names, model_args)
exporter = PTExporter(self.model, input_names, output_names, model_args) # type: ignore
if save_format is not None:
exporter.export_model(save_path, save_format)
else:
@@ -125,7 +125,7 @@ def disable_scheduler(self) -> None:
def disable_scheduler(self) -> None:
self._scheduler = StubCompressionScheduler()

def set_builder_state_with_name(self, name: str, builder_state: Dict):
def set_builder_state_with_name(self, name: str, builder_state: Dict[str, Any]) -> None:
"""
Sets state of the builder and the corresponding algorithm name. Should be called by the builder to set its
state and registered algorithm key.
@@ -146,7 +146,7 @@ def load_state(self, state: Dict[str, Dict[str, Any]]) -> None:
algo_state = state[self.name]
if self._state_names.COMPRESSION_STAGE in state:
compression_stage = state[self._state_names.COMPRESSION_STAGE]
if self.compression_stage() != compression_stage:
if self.compression_stage() != compression_stage: # type: ignore
nncf_logger.warning(
f"Current CompressionStage ({self.compression_stage()}) of the compression controller "
f"does not correspond to the value found in the checkpoint ({compression_stage})"
@@ -223,7 +223,7 @@ def __init__(self, config: NNCFConfig, should_init: bool = True):
if self.target_scopes is None:
self.target_scopes = algo_target_scopes

def _get_algo_specific_config_section(self) -> Dict:
def _get_algo_specific_config_section(self) -> Dict[str, Any]:
return extract_algo_specific_config(self.config, self.name)

@property
Expand Down Expand Up @@ -273,7 +273,7 @@ def build_controller(self, model: TModel) -> BaseCompressionAlgorithmController:
return ctrl

@abstractmethod
def _load_state_without_name(self, state_without_name: Dict[str, Any]):
def _load_state_without_name(self, state_without_name: Dict[str, Any]) -> None:
"""
Implementation of load state that takes state without builder name.

Expand All @@ -289,7 +289,7 @@ def _get_state_without_name(self) -> Dict[str, Any]:
(dict, list, tuple, str, int, float, True, False, None) that represents state of the object.
"""

def _parse_bn_adapt_params(self) -> Optional[Dict]:
def _parse_bn_adapt_params(self) -> Optional[Dict[str, Any]]:
try:
return extract_bn_adaptation_init_params(self.config, self.name)
except BNAdaptDataLoaderNotFoundError as e:
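The `_name`/`_builder_state` changes above follow one idiom: keep the backing attribute `Optional`, and let the public property narrow it by raising when the value was never set, so callers see a plain `str`. A small sketch of the idiom (stand-in error type, not `nncf.InternalError`):

```python
from typing import Optional


class ToyController:
    def __init__(self) -> None:
        self._name: Optional[str] = None

    @property
    def name(self) -> str:
        # Raising here narrows Optional[str] to str for every call site.
        if self._name is None:
            raise RuntimeError("Internal error: name of the controller is not set!")
        return self._name

    def set_builder_state_with_name(self, name: str) -> None:
        self._name = name


ctrl = ToyController()
ctrl.set_builder_state_with_name("quantization")
print(ctrl.name)  # "quantization"; accessing the property before setting would raise
```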
14 changes: 7 additions & 7 deletions nncf/common/quantization/config_assignment.py
@@ -9,7 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from copy import deepcopy
from typing import Dict, List
from typing import Any, Dict, List, Optional

from nncf.common.graph import NNCFNode
from nncf.common.hardware.config import HWConfig
@@ -19,7 +19,7 @@


def get_scoped_quantizer_config(
base_config: QuantizerConfig, scope_str: str, scope_overrides: Dict = None
base_config: QuantizerConfig, scope_str: str, scope_overrides: Optional[Dict[str, Any]] = None
) -> QuantizerConfig:
"""
Returns a QuantizerConfig which is based on a given config, which will have overrides
@@ -54,8 +54,8 @@ def assign_qconfig_lists_to_modules(
nodes_with_weights: List[NNCFNode],
default_weight_qconfig: QuantizerConfig,
global_weight_constraints: QuantizationConstraints = None,
scope_overrides_dict: Dict = None,
hw_config: HWConfig = None,
scope_overrides_dict: Optional[Dict[str, Any]] = None,
hw_config: Optional[HWConfig] = None,
) -> Dict[NNCFNode, List[QuantizerConfig]]:
"""
Assigns a list of possible quantizer configurations (as determined by HW config, defaults and overrides)
@@ -89,7 +89,7 @@ def assign_qconfig_lists_to_modules(
qconfig_list = [qconfig_for_current_scope]
else:
metatype = node.metatype
qconfig_list = meta_vs_qconfig_map[metatype]
qconfig_list = meta_vs_qconfig_map[metatype] # type: ignore
if HWConfig.is_wildcard_quantization(qconfig_list): # Empty list = wildcard quantization
qconfig_list = [default_qconfig]
elif HWConfig.is_qconf_list_corresponding_to_unspecified_op(qconfig_list):
@@ -99,8 +99,8 @@
for overridden_scope, scoped_override_dict in scope_overrides_dict.items():
if matches_any(node.node_name, overridden_scope):
scope_constraints = QuantizationConstraints.from_config_dict(scoped_override_dict)
local_constraints = local_constraints.get_updated_constraints(scope_constraints)
qconfig_list = local_constraints.constrain_qconfig_list(
local_constraints = local_constraints.get_updated_constraints(scope_constraints) # type: ignore
qconfig_list = local_constraints.constrain_qconfig_list( # type: ignore
node.node_name, hw_config.target_device, qconfig_list
)

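This file mostly converts implicit-Optional defaults (`Dict = None`, `HWConfig = None`) into explicit `Optional[...]` annotations and normalizes the `None` case before use. A short sketch of that conversion with a hypothetical helper (not part of NNCF):

```python
from typing import Any, Dict, Optional


def build_overrides(scope_overrides: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    # Normalizing early keeps the rest of the function typed as a plain dict.
    if scope_overrides is None:
        scope_overrides = {}
    return {scope: dict(cfg) for scope, cfg in scope_overrides.items()}


print(build_overrides())                         # {}
print(build_overrides({"conv.*": {"bits": 4}}))  # {'conv.*': {'bits': 4}}
```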
2 changes: 1 addition & 1 deletion nncf/common/quantization/quantizer_propagation/graph.py
@@ -1381,7 +1381,7 @@ def create_quantizer_setup(
)
for pq_set in pq_sets_grouped_by_unified_scale:
setup.register_unified_scale_group_with_types(
[pqid_vs_qpid[pq.id] for pq in pq_set], [pq.unified_scale_type for pq in pq_set]
[pqid_vs_qpid[pq.id] for pq in pq_set], [pq.unified_scale_type for pq in pq_set] # type: ignore
)

setup = self._handle_output_quantizers_for_weights_as_outputs_ops(setup, pqid_vs_qpid, wao_op_node_key_vs_wq_id)