Skip to content

Commit

Permalink
Fixes for numpy 2, and more
Browse files Browse the repository at this point in the history
Fixed issues from running with numpy 2. Also cleaned up some linting issues.
  • Loading branch information
iamsrp-deshaw committed Nov 6, 2024
1 parent be911f2 commit 1fa71f5
Showing 1 changed file with 12 additions and 13 deletions.
25 changes: 12 additions & 13 deletions python/pjrmi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1369,7 +1369,7 @@ def _send(self, msg_type, payload):
internally.
"""

assert(payload is not None)
assert payload is not None

# We can't currently send messages larger than 2GB in size. This will
# require a number of changes on both sides, partly since Java arrays
Expand Down Expand Up @@ -1455,7 +1455,7 @@ def _recv(self):
payload = b''
while len(payload) < payload_size:
payload += self._transport.recv(payload_size - len(payload))
assert(len(payload) == payload_size)
assert len(payload) == payload_size

# See if it happened to be a callback
if request_id == self._CALLBACK_REQUEST_ID:
Expand Down Expand Up @@ -2533,7 +2533,7 @@ def _format_string(self, string):
Format a string as [int32:size][bytes[]:string], to send down the wire
"""

if type(string) != bytes:
if type(string) is not bytes:
string = string.encode('ascii')
return self._format_int32(len(string)) + string

Expand Down Expand Up @@ -2625,10 +2625,9 @@ def _format_array(self, value, dtype):
if len(arr) > self._MAX_JAVA_ARRAY_SIZE:
raise TypeError('The given array is larger than Java can represent')

return bytes(numpy.array(strict_array(dtype.type, arr),
dtype=dtype,
copy=False,
order='C').data)
return bytes(numpy.asarray(strict_array(dtype.type, arr),
dtype=dtype,
order='C').data)


def _format_method_as(self, value, klass):
Expand Down Expand Up @@ -2682,7 +2681,7 @@ def _format_as_lambda(self,
# argument list, or none at all
if not method._is_static:
if (len(args) == len(method._argument_type_ids) + 1 and
type(args[0]) == method._klass):
type(args[0]) is method._klass):
this = args[0]
args = args[1:]
else:
Expand Down Expand Up @@ -3078,17 +3077,17 @@ def _format_by_class(self, klass, value,
for el in value))

elif isinstance(value, int):
if numpy.can_cast(value, numpy.int8):
if -128 <= value < 128:
return (self._ARGUMENT_VALUE +
self._format_int32(self._java_lang_Byte._type_id) +
self._format_int8(strict_number(numpy.int8, value)))

elif numpy.can_cast(value, numpy.int16):
elif -16384 <= value < 16384:
return (self._ARGUMENT_VALUE +
self._format_int32(self._java_lang_Short._type_id) +
self._format_int16(strict_number(numpy.int16, value)))

elif numpy.can_cast(value, numpy.int32):
elif -2147483648 <= value < 2147483648:
return (self._ARGUMENT_VALUE +
self._format_int32(self._java_lang_Integer._type_id) +
self._format_int32(strict_number(numpy.int32, value)))
Expand Down Expand Up @@ -3254,7 +3253,7 @@ def _format_by_class(self, klass, value,
# If the user has given us a numpy type then we assume that they
# know what they are doing when it comes to types, and we
# disallow a lossy conversion.
if type(value) == numpy.float64:
if type(value) is numpy.float64:
raise ValueError("%s is not assignable to %s" %
(type(value), klass._classname))
else:
Expand Down Expand Up @@ -3542,7 +3541,7 @@ def _format_by_class(self, klass, value,
self._format_int32(self._get_object_id(value)))

elif (klass._type_id == self._com_deshaw_hypercube_Hypercube._type_id and
type(value) == numpy.ndarray):
type(value) is numpy.ndarray):
return (self._ARGUMENT_VALUE +
self._format_int32(klass._type_id) +
self._format_by_class(self._L_java_lang_long, value.shape) +
Expand Down

0 comments on commit 1fa71f5

Please sign in to comment.