
Jitted Function Save/Load #229

Open
xiazhuo opened this issue Jan 3, 2025 · 1 comment
Labels
bug Something isn't working

Comments


xiazhuo commented Jan 3, 2025

Issue Description

I am running into an issue when saving and loading JIT-compiled functions with tensorcircuit.keras.save_func() and tensorcircuit.keras.load_func(). Specifically, I am trying to save and load the qpred (or qlayer) function of my hybrid model, but loading the function fails with the following error:

  File "/home/.miniconda3/envs/qml/lib/python3.10/site-packages/tensorcircuit/keras.py", line 284, in wrapper  *
    return m.f(*args, **kws)
AttributeError: '_UserObject' object has no attribute 'f'

Here is the code that I am working with:

import torch
import tensorflow as tf
import tensorcircuit as tc

# `circuit(inputs, weights, trunk_size)` builds the parameterized quantum circuit
# and is defined elsewhere in my script.

class HybridModel(torch.nn.Module):
    def __init__(self, trunk_size, n_layers=2, n_hidden_layers=4, n_wires=2):
        super().__init__()
        K = tc.set_backend("tensorflow")
        tf_device = "/gpu"

        @tf.function
        def qpred(inputs, weights):
            # Expectation value of Z on every wire, evaluated on the TensorFlow backend.
            with tf.device(tf_device):
                c = circuit(inputs, weights, trunk_size)
                observables = K.stack([K.real(c.expectation_ps(z=[i]))
                                       for i in range(n_wires)])
                return observables

        self.qpred = qpred
        self.qlayer = tc.TorchLayer(
            self.qpred,
            weights_shape=[2 * n_layers, n_hidden_layers, n_wires, 2],
            use_jit=True,
            enable_dlpack=True,
        )
        self.clayer = torch.nn.Linear(n_wires, 1)

    def forward(self, inputs):
        outputs = self.qlayer(inputs)
        outputs = torch.mean(outputs, axis=1)
        return outputs

What I have tried:

  • I have attempted to use tensorcircuit.keras.save_func() and tensorcircuit.keras.load_func() to save and load qpred (or qlayer), but both attempts result in the error above.

I am wondering if there is a different approach to saving/loading the JIT-compiled function, or if there is a potential issue with the way TensorCircuit handles saved functions in this context.

Would you be able to provide guidance or suggest an alternative solution for saving/loading the function, especially one that involves JIT compilation?

Thank you very much for your time and assistance. I appreciate any help you can provide!

Environment Context

OS info: Linux-5.4.0-150-generic-x86_64-with-glibc2.27
Python version: 3.10.14
Numpy version: 1.26.4
Scipy version: 1.12.0
Pandas version: 2.2.2
TensorNetwork version: 0.5.1
Cotengra version: 0.6.2
TensorFlow version: 2.18.0
TensorFlow GPU: [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU'), PhysicalDevice(name='/physical_device:GPU:1', device_type='GPU'), PhysicalDevice(name='/physical_device:GPU:2', device_type='GPU'), PhysicalDevice(name='/physical_device:GPU:3', device_type='GPU'), PhysicalDevice(name='/physical_device:GPU:4', device_type='GPU')]
TensorFlow CUDA infos: {'cpu_compiler': '/usr/lib/llvm-18/bin/clang', 'cuda_compute_capabilities': ['sm_60', 'sm_70', 'sm_80', 'sm_89', 'compute_90'], 'cuda_version': '12.5.1', 'cudnn_version': '9', 'is_cuda_build': True, 'is_rocm_build': False, 'is_tensorrt_build': False}
Jax version: 0.4.23
Jax GPU: [cuda(id=0), cuda(id=1), cuda(id=2), cuda(id=3), cuda(id=4)]
JaxLib version: 0.4.23
PyTorch version: 2.5.1+cu124
PyTorch GPU support: True
PyTorch GPUs: [<torch.cuda.device object at 0x7fec328bd8d0>, <torch.cuda.device object at 0x7fec328bd900>, <torch.cuda.device object at 0x7fec328bd8a0>, <torch.cuda.device object at 0x7fec328bdc30>, <torch.cuda.device object at 0x7fec328bdc90>]
Pytorch cuda version: 12.4
Cupy is not installed
Qiskit version: 1.3.1
Cirq version: 1.4.1
TensorCircuit version 0.12.0

xiazhuo added the bug label Jan 3, 2025
refraction-ray (Contributor) commented Jan 5, 2025

This works for me for saving the TensorFlow jitted function, at least on CPU:

tf_device = "/cpu"
n_wires = 8

def circuit(inputs, weights):
    c = tc.Circuit(n_wires, inputs=inputs)
    c.rz(range(n_wires), theta=weights)
    return c

@tf.function
def qpred(inputs, weights):
    with tf.device(tf_device):
        c = circuit(inputs, weights)
        observables = K.stack([K.real(c.expectation_ps(x=[i]))
                                for i in range(n_wires)])
        return observables

print(qpred(K.ones([2**n_wires])/K.cast(K.sqrt(2.0**n_wires), "complex64"), 0.3*K.real(K.ones([n_wires]))))

tc.keras.save_func(qpred, "tempsave")
f = tc.keras.load_func("tempsave")
print(f(K.ones([2**n_wires])/K.cast(K.sqrt(2.0**n_wires), "complex64"), 0.3*K.real(K.ones([n_wires]))))

The model instance itself cannot be saved via TensorFlow tools, since the model is an instance of a torch module.
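
If the goal is simply to checkpoint the whole hybrid model, the torch-native route of saving the state_dict and rebuilding the jitted function from code should work. A minimal sketch, assuming (as in recent TensorCircuit versions) that TorchLayer registers its quantum weights as torch parameters; the toy circuit and file name below are illustrative, not taken from the original post:

import torch
import tensorcircuit as tc

K = tc.set_backend("tensorflow")
n_wires = 4

def qpred(inputs, weights):
    # Toy circuit standing in for the original `circuit`.
    c = tc.Circuit(n_wires, inputs=inputs)
    for i in range(n_wires):
        c.rx(i, theta=weights[i])
    return K.stack([K.real(c.expectation_ps(z=[i])) for i in range(n_wires)])

class Hybrid(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # TorchLayer keeps its quantum weights as torch parameters,
        # so they show up in state_dict().
        self.qlayer = tc.TorchLayer(qpred, weights_shape=[n_wires],
                                    use_jit=True, enable_dlpack=True)
        self.clayer = torch.nn.Linear(n_wires, 1)

    def forward(self, x):
        return self.clayer(self.qlayer(x))

model = Hybrid()
torch.save(model.state_dict(), "hybrid.pt")   # parameters only

restored = Hybrid()                           # rebuild the architecture from code
restored.load_state_dict(torch.load("hybrid.pt"))

Only the parameters are serialized here; the jitted qpred is reconstructed when the class is instantiated again, which sidesteps the SavedModel restore path entirely.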

For a JAX jitted function, please refer to https://jax.readthedocs.io/en/latest/export/export.html for IO.
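
For reference, a minimal sketch of that export workflow, assuming a recent JAX version where the API is available as jax.export (older releases such as 0.4.23 ship an earlier variant under jax.experimental.export with slightly different names):

import jax
import jax.numpy as jnp
from jax import export

def f(x):
    return jnp.sum(jnp.sin(x))

# Export the jitted function for a concrete input shape/dtype,
# then serialize it to bytes that can be written to disk.
exported = export.export(jax.jit(f))(jax.ShapeDtypeStruct((8,), jnp.float32))
blob = exported.serialize()

# Later (possibly in another process): deserialize and call the restored function.
restored = export.deserialize(blob)
print(restored.call(jnp.ones((8,), jnp.float32)))

The serialized bytes play the same role as save_func/load_func on the TensorFlow backend: they can be stored on disk and the restored function called without re-tracing the original Python code.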

Update: I have also implemented function wrappers in TensorCircuit-NG to save/load jitted JAX functions; see https://tensorcircuit-ng.readthedocs.io/en/latest/advance.html#jitted-function-save-load
