Skip to content

Commit

Permalink
Fixed RunMode and Handler (#23)
Browse files Browse the repository at this point in the history
  • Loading branch information
Raalsky authored Dec 10, 2021
1 parent b4b99dd commit 8e7a560
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 16 deletions.
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -141,11 +141,11 @@ You can always go to those files and change the initial configuration.
import neptune.new as neptune
```

* Add **neptune\_run** argument of type `neptune.run.Handler` to the `report_accuracy` function
* Add **neptune\_run** argument of type `neptune.handler.Handler` to the `report_accuracy` function

```python
def report_accuracy(predictions: np.ndarray, test_y: pd.DataFrame,
neptune_run: neptune.run.Handler) -> None:
neptune_run: neptune.handler.Handler) -> None:
...
```

Expand All @@ -158,7 +158,7 @@ You have to use a special string "**neptune\_run"** to use the Neptune Run handl

```python
def report_accuracy(predictions: np.ndarray, test_y: pd.DataFrame,
neptune_run: neptune.run.Handler) -> None:
neptune_run: neptune.handler.Handler) -> None:
target = np.argmax(test_y.to_numpy(), axis=1)
accuracy = np.sum(predictions == target) / target.shape[0]

Expand All @@ -171,7 +171,7 @@ You can log metadata from any node to any [Neptune namespace](../../you-should-k

```python
def report_accuracy(predictions: np.ndarray, test_y: pd.DataFrame,
neptune_run: neptune.run.Handler) -> None:
neptune_run: neptune.handler.Handler) -> None:
target = np.argmax(test_y.to_numpy(), axis=1)
accuracy = np.sum(predictions == target) / target.shape[0]

Expand Down
2 changes: 1 addition & 1 deletion examples/planets/src/planets/pipelines/moons_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def __call__(self, trial: optuna.trial.Trial):
return probs[0]


def optimize(neptune_run: neptune.run.Handler, model: fastai.tabular.model.TabularModel):
def optimize(neptune_run: neptune.handler.Handler, model: fastai.tabular.model.TabularModel):
study = optuna.create_study(direction="minimize")
study.optimize(
Objective(neptune_run=neptune_run._run, model=model),
Expand Down
20 changes: 9 additions & 11 deletions kedro_neptune/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,12 @@
import neptune.new as neptune
from neptune.new.types import File
from neptune.new.internal.utils import verify_type
from neptune.new.internal.init_impl import RunMode
from neptune.new.internal.utils.paths import join_paths
except ImportError:
# neptune-client>=1.0.0 package structure
import neptune
from neptune.types import File
from neptune.internal.utils import verify_type
from neptune.internal.init_impl import RunMode
from neptune.internal.utils.paths import join_paths

INTEGRATION_VERSION_KEY = 'source_code/integrations/kedro-neptune'
Expand Down Expand Up @@ -211,7 +209,7 @@ def init(metadata: ProjectMetadata, api_token: str, project: str, base_namespace


def _connection_mode(enabled: bool) -> str:
return RunMode.ASYNC if enabled else RunMode.DEBUG
return 'async' if enabled else 'debug'


class NeptuneRunDataSet(AbstractDataSet):
Expand All @@ -224,7 +222,7 @@ def _describe(self) -> Dict[str, Any]:
def _exists(self) -> bool:
return True

def _load(self) -> neptune.run.Handler:
def _load(self) -> neptune.handler.Handler:
config = get_neptune_config()

run = neptune.init(api_token=config.api_token,
Expand Down Expand Up @@ -329,7 +327,7 @@ def __init__(
)


def log_file_dataset(namespace: neptune.run.Handler, name: str, dataset: NeptuneFileDataSet):
def log_file_dataset(namespace: neptune.handler.Handler, name: str, dataset: NeptuneFileDataSet):
# pylint: disable=protected-access
if not namespace._run.exists(f'{namespace._path}/{name}'):
data = dataset.load()
Expand All @@ -348,12 +346,12 @@ def log_file_dataset(namespace: neptune.run.Handler, name: str, dataset: Neptune
)


def log_parameters(namespace: neptune.run.Handler, catalog: DataCatalog):
def log_parameters(namespace: neptune.handler.Handler, catalog: DataCatalog):
# pylint: disable=protected-access
namespace['parameters'] = catalog._data_sets['parameters'].load()


def log_dataset_metadata(namespace: neptune.run.Handler, name: str, dataset: AbstractDataSet):
def log_dataset_metadata(namespace: neptune.handler.Handler, name: str, dataset: AbstractDataSet):
additional_parameters = {}
try:
# pylint: disable=protected-access
Expand All @@ -368,7 +366,7 @@ def log_dataset_metadata(namespace: neptune.run.Handler, name: str, dataset: Abs
}


def log_data_catalog_metadata(namespace: neptune.run.Handler, catalog: DataCatalog):
def log_data_catalog_metadata(namespace: neptune.handler.Handler, catalog: DataCatalog):
# pylint: disable=protected-access
namespace = namespace['catalog']

Expand All @@ -383,7 +381,7 @@ def log_data_catalog_metadata(namespace: neptune.run.Handler, catalog: DataCatal
log_parameters(namespace=namespace, catalog=catalog)


def log_pipeline_metadata(namespace: neptune.run.Handler, pipeline: Pipeline):
def log_pipeline_metadata(namespace: neptune.handler.Handler, pipeline: Pipeline):
namespace['structure'].upload(File.from_content(
json.dumps(
json.loads(pipeline.to_json()),
Expand All @@ -394,11 +392,11 @@ def log_pipeline_metadata(namespace: neptune.run.Handler, pipeline: Pipeline):
))


def log_run_params(namespace: neptune.run.Handler, run_params: Dict[str, Any]):
def log_run_params(namespace: neptune.handler.Handler, run_params: Dict[str, Any]):
namespace['run_params'] = run_params


def log_command(namespace: neptune.run.Handler):
def log_command(namespace: neptune.handler.Handler):
namespace['kedro_command'] = ' '.join(['kedro'] + sys.argv[1:])


Expand Down

0 comments on commit 8e7a560

Please sign in to comment.