Skip to content

Commit

Permalink
Fixed examples
Browse files Browse the repository at this point in the history
  • Loading branch information
RandomDefaultUser committed Jan 14, 2025
1 parent 07126f1 commit a7d7bd3
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 13 deletions.
8 changes: 8 additions & 0 deletions examples/advanced/ex02_shuffle_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,14 @@
"Be_snapshot1.out.npy",
data_path,
)
# On-the-fly snapshots can be added as well.
# data_shuffler.add_snapshot(
# "Be_snapshot2.info.json",
# data_path,
# "Be_snapshot2.out.npy",
# data_path,
# )


# Shuffle the snapshots using the "shuffle_to_temporary" option.
data_shuffler.shuffle_snapshots(
Expand Down
24 changes: 17 additions & 7 deletions examples/basic/ex01_train_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,15 +65,25 @@
data_handler.add_snapshot(
"Be_snapshot0.in.npy", data_path, "Be_snapshot0.out.npy", data_path, "tr"
)
# Add snapshots with "raw" (=MALA formatted) JSON, computation of descriptors
# will be performed "on-the-fly".
data_handler.add_snapshot(
"Be_snapshot1.info.json",
data_path,
"Be_snapshot1.out.npy",
data_path,
"va",
"Be_snapshot1.in.npy", data_path, "Be_snapshot1.out.npy", data_path, "va"
)
# Add snapshots with "raw" (=MALA formatted) JSON, computation of descriptors
# will be performed "on-the-fly".
# data_handler.add_snapshot(
# "Be_snapshot0.info.json",
# data_path,
# "Be_snapshot0.out.npy",
# data_path,
# "tr",
# )
# data_handler.add_snapshot(
# "Be_snapshot1.info.json",
# data_path,
# "Be_snapshot1.out.npy",
# data_path,
# "va",
# )

data_handler.prepare_data()

Expand Down
24 changes: 20 additions & 4 deletions examples/basic/ex02_test_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,17 +50,33 @@
"te",
calculation_output_file=os.path.join(data_path, "Be_snapshot2.info.json"),
)

# Add snapshots with "raw" (=MALA formatted) JSON, computation of descriptors
# will be performed "on-the-fly".
data_handler.add_snapshot(
"Be_snapshot3.info.json",
"Be_snapshot3.in.npy",
data_path,
"Be_snapshot3.out.npy",
data_path,
"te",
calculation_output_file=os.path.join(data_path, "Be_snapshot3.info.json"),
)

# Add snapshots with "raw" (=MALA formatted) JSON, computation of descriptors
# will be performed "on-the-fly".
# data_handler.add_snapshot(
# "Be_snapshot2.info.json",
# data_path,
# "Be_snapshot2.out.npy",
# data_path,
# "te",
# calculation_output_file=os.path.join(data_path, "Be_snapshot2.info.json"),
# )
# data_handler.add_snapshot(
# "Be_snapshot3.info.json",
# data_path,
# "Be_snapshot3.out.npy",
# data_path,
# "te",
# calculation_output_file=os.path.join(data_path, "Be_snapshot3.info.json"),
# )
data_handler.prepare_data(reparametrize_scaler=False)


Expand Down
19 changes: 17 additions & 2 deletions examples/basic/ex03_preprocess_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,9 @@
# more convenient *.json files that can be used in their stead. This saves
# on disk space and makes the process more reproducible.
# To only process parts of the data, omit/add descriptor_input*, target_input_*
# and simulation_output_* at your leisure.
# and simulation_output_* at your leisure. This is especially useful if you,
# e.g., do not need to convert the descriptor data, since it will be
# calculated on-the-fly during training.
# Make sure to set the correct units - for QE, this should always be
# 1/(Ry*Bohr^3).
####################
Expand All @@ -60,6 +62,7 @@
outfile = os.path.join(data_path, "Be_snapshot0.out")
ldosfile = os.path.join(data_path, "cubes/tmp.pp*Be_ldos.cube")

# Converting a snapshot for training on precomputed descriptor data.
data_converter.add_snapshot(
descriptor_input_type="espresso-out",
descriptor_input_path=outfile,
Expand All @@ -70,6 +73,16 @@
target_units="1/(Ry*Bohr^3)",
)

# Converting a snapshot for training with on-the-fly descriptor calculation.
# data_converter.add_snapshot(
# target_input_type=".cube",
# target_input_path=ldosfile,
# simulation_output_type="espresso-out",
# simulation_output_path=outfile,
# target_units="1/(Ry*Bohr^3)",
# )


####################
# 3. Converting the data
# To convert the data we now simply have to call the convert_snapshot function.
Expand All @@ -82,9 +95,11 @@
####################

data_converter.convert_snapshots(
descriptor_save_path="./",
target_save_path="./",
simulation_output_save_path="./",
# The next line should be omitted, if the descriptor data is to be
# calculated on-the-fly during training.
descriptor_save_path="./",
naming_scheme="Be_snapshot*.npy",
descriptor_calculation_kwargs={"working_directory": data_path},
)
Expand Down

0 comments on commit a7d7bd3

Please sign in to comment.