Skip to content

Commit

Permalink
Merge branch 'main' into better_direct
Browse files Browse the repository at this point in the history
  • Loading branch information
giadarol committed Dec 19, 2024
2 parents dac39e0 + 73c9d3d commit 654d1a0
Show file tree
Hide file tree
Showing 4 changed files with 116 additions and 21 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/deploy.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ jobs:
- name: Build wheels
run: python -m cibuildwheel --output-dir wheelhouse
env:
CIBW_PROJECT_REQUIRES_PYTHON: ">=3.7"
CIBW_PROJECT_REQUIRES_PYTHON: ">=3.8"

- uses: actions/upload-artifact@v4
with:
Expand Down
85 changes: 75 additions & 10 deletions tests/test_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,18 @@

from xdeps import Table

data = {
"name": np.array(["ip1", "ip2", "ip2", "ip3", "tab$end"]),
"s": np.array([1.0, 2.0, 2.1, 3.0, 4.0]),
"betx": np.array([4.0, 5.0, 5.1, 6.0, 7.0]),
"bety": np.array([2.0, 3.0, 3.1, 4.0, 9.0]),
}

def get_a_table():
data = {
"name": np.array(["ip1", "ip2", "ip2", "ip3", "tab$end"]),
"s": np.array([1.0, 2.0, 2.1, 3.0, 4.0]),
"betx": np.array([4.0, 5.0, 5.1, 6.0, 7.0]),
"bety": np.array([2.0, 3.0, 3.1, 4.0, 9.0]),
}
t = Table(data)
return t, data

## Table tests

t = Table(data)

t, data = get_a_table()

def test_table_initialization():
# Valid initialization
Expand Down Expand Up @@ -117,6 +117,48 @@ def test_table_getitem_edge_cases():
assert t[("betx",)][2] == t["betx"][2]


def test_table_setitem_col():
t, data = get_a_table()
t["2betx"] = t["betx"] * 2
assert np.array_equal(t["2betx"], data["betx"] * 2)
t["betx"] = 1
assert np.array_equal(t["betx"], np.ones(len(data["betx"])))
t["name"] = t["name"] * 2
assert np.all(t['name'] == [x * 2 for x in data['name']])


def test_table_setitem_col_row():
t, data = get_a_table()
t["betx", 1] = 10
assert t["betx", 1] == 10
t["betx", "ip2"] = 20
assert t["betx", "ip2"] == 20
t["betx", "ip2::1"] = 30
assert t["betx", "ip2::1"] == 30
t["betx", "ip2<<1"] = 40
assert t["betx", "ip2<<1"] == 40
t["betx", "ip2::1>>1"] = 50
assert t["betx", "ip2::1>>1"] == 50
t["betx", "ip2::1>>1"] = 50
assert t["betx", "ip2::1>>1"] == 50
t["betx", "ip2::1>>1"] = 50
assert t["betx", "ip2::1>>1"] == 50
t["betx", "ip2::1>>1"] = 50
assert t["betx", "ip2::1>>1"] == 50
t["betx", "ip2::1>>1"] = 50
assert t["betx", "ip2::1>>1"] == 50
t["betx", "ip2::1>>1"] = 50
assert t["betx", "ip2::1>>1"] == 50
t["betx", "ip2::1>>1"] = 50
assert t["betx", "ip2::1>>1"] == 50
t["betx", "ip2::1>>1"] = 50
assert t["betx", "ip2::1>>1"] == 50
t["betx", "ip2::1>>1"] = 50
assert t["betx", "ip2::1>>1"] == 50
t["betx", "ip2::1>>1"] = 50
assert t["betx", "ip2::1>>1"] == 50


def test_table_numpy_string():
tab = Table(dict(name=np.array(["a", "b$b"]), val=np.array([1, 2])))
assert tab["val", tab.name[1]] == 2
Expand Down Expand Up @@ -242,6 +284,29 @@ def test_table_show():
assert table.show(output=str) == "name value\na 1\nb 2\nc 3"


def test_table_show_rows():
data = {"name": np.array(["a", "b", "c"]), "value": np.array([1, 2, 3])}
table = Table(data)
assert table.show(rows=1, output=str) == "name value\nb 2"


def test_table_show_cols():
data = {"name": np.array(["a", "b", "c"]), "value": np.array([1, 2, 3])}
table = Table(data)
assert (
table.show(cols="value", output=str)
== "name value\na 1\nb 2\nc 3"
)


def test_table_show_rows_cols():
t, data = get_a_table()
assert (
t.show(rows="ip2.*", cols="betx", output=str)
== "name betx\nip1 5\nip2::0 5.1"
)


## Table cols tests


Expand Down
2 changes: 1 addition & 1 deletion xdeps/_version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.8.2'
__version__ = '0.8.3'
48 changes: 39 additions & 9 deletions xdeps/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,8 @@ def _get_row_indices(self, row):
return np.array(out, dtype=int)
else:
raise ValueError(f"Invalid row selector {row}")
elif row is None:
return slice(None)
else:
return [self._get_row_index(row)]

Expand Down Expand Up @@ -477,7 +479,10 @@ def _select(self, rows, cols):
col_list.insert(0, self._index)
data = {}
for cc in col_list:
data[cc] = eval(cc, gblmath, view)
try:
data[cc] = view[cc]
except KeyError:
data[cc] = eval(cc, gblmath, view)
for kk in self.keys(exclude_columns=True):
data[kk] = self._data[kk]
return self.__class__(
Expand Down Expand Up @@ -543,7 +548,7 @@ def __getitem__(self, args):
return self._data[args]
except KeyError:
return eval(args, gblmath, self._data)
if type(args) is tuple: # multiple args
if isinstance(args,tuple): # multiple args
if len(args) == 0:
col = None
row = None
Expand Down Expand Up @@ -824,14 +829,39 @@ def __setitem__(self, key, val):
object.__setattr__(self, "_index_cache", None)
object.__setattr__(self, "_count_cache", None)
object.__setattr__(self, "_name_cache", None)
if key in self.__dict__:
object.__setattr__(self, key, val)
elif key in self._col_names:
self._data[key][:] = val

if isinstance(key, str):
if key in self.__dict__:
object.__setattr__(self, key, val)
elif key in self._col_names:
self._data[key][:] = val
else:
self._data[key] = val
if hasattr(val, "__iter__") and len(val) == len(self):
self._col_names.append(key)
elif isinstance(key,tuple):
col,row = key
col = self._data[col]
if isinstance(row, str):
cache, count = self._get_cache()
idx = cache.get((row, 0))
if idx is None:
name, count, offset = self._split_name_count_offset(row)
idx = self._get_row_cache_raise(name, count, offset)
elif isinstance(row, tuple):
cache, count = self._get_cache()
idx = cache.get(row)
if idx is None:
idx = self._get_row_cache_raise(*row)
elif isinstance(row, slice):
idx = self._get_row_indices(row)
elif isinstance(row, list):
idx = self._get_row_indices(row)
else:
idx = row
col[idx]=val
else:
self._data[key] = val
if hasattr(val, "__iter__") and len(val) == len(self):
self._col_names.append(key)
raise ValueError(f"Invalid key {key}")

__setattr__ = __setitem__

Expand Down

0 comments on commit 654d1a0

Please sign in to comment.