Skip to content

Commit

Permalink
Testcase for fix 10 (#11)
Browse files Browse the repository at this point in the history
Fixes #10
  • Loading branch information
glorialeezero authored Nov 25, 2024
2 parents 31f5021 + dc1c813 commit 7587179
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 17 deletions.
18 changes: 9 additions & 9 deletions qupsy/language.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@ def depth(self) -> int:
)

@property
def filled(self) -> bool:
def terminated(self) -> bool:
for aexp in self.children:
if not aexp.filled:
if not aexp.terminated:
return False
return True

Expand Down Expand Up @@ -104,7 +104,7 @@ def children(self) -> list[Aexp]:
return []

@property
def filled(self) -> bool:
def terminated(self) -> bool:
return False


Expand Down Expand Up @@ -227,9 +227,9 @@ def depth(self) -> int:
)

@property
def filled(self) -> bool:
def terminated(self) -> bool:
for aexp in self.children:
if not aexp.filled:
if not aexp.terminated:
return False
return True

Expand All @@ -253,7 +253,7 @@ def children(self) -> list[Aexp]:
return []

@property
def filled(self) -> bool:
def terminated(self) -> bool:
return False


Expand Down Expand Up @@ -410,9 +410,9 @@ def depth(self) -> int:
)

@property
def filled(self) -> bool:
def terminated(self) -> bool:
for child in self.children:
if not child.filled:
if not child.terminated:
return False
return True

Expand All @@ -436,7 +436,7 @@ def children(self) -> list[Cmd | Gate | Aexp]:
return []

@property
def filled(self) -> bool:
def terminated(self) -> bool:
return False


Expand Down
16 changes: 8 additions & 8 deletions qupsy/transition.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def visit_Aexp(self, aexp: Aexp) -> list[Aexp]:
return visitor(aexp)

for i, child in enumerate(aexp.children):
if child.filled:
if child.terminated:
continue
next_children = self.visit_Aexp(child)
pre_args = aexp.children[:i]
Expand All @@ -50,7 +50,7 @@ def visit_Aexp(self, aexp: Aexp) -> list[Aexp]:
*(a.copy() for a in post_args),
)
)

return ret
return []

def visit_HoleGate(self, gate: HoleGate) -> list[Gate]:
Expand All @@ -63,7 +63,7 @@ def visit_Gate(self, gate: Gate) -> list[Gate]:
return visitor(gate)

for i, child in enumerate(gate.children):
if child.filled:
if child.terminated:
continue
next_children = self.visit_Aexp(child)
pre_args = gate.children[:i]
Expand All @@ -84,13 +84,13 @@ def visit_HoleCmd(self, cmd: HoleCmd) -> list[Cmd]:
return [SeqCmd(), ForCmd(var=f"i{self.for_depth}"), GateCmd()]

def visit_SeqCmd(self, cmd: SeqCmd) -> list[Cmd]:
if not cmd.pre.filled:
if not cmd.pre.terminated:
pres = self.visit_Cmd(cmd.pre)
ret: list[Cmd] = []
for pre in pres:
ret.append(SeqCmd(pre=pre, post=cmd.post.copy()))
return ret
if not cmd.post.filled:
if not cmd.post.terminated:
posts = self.visit_Cmd(cmd.post)
ret: list[Cmd] = []
for post in posts:
Expand All @@ -99,7 +99,7 @@ def visit_SeqCmd(self, cmd: SeqCmd) -> list[Cmd]:
return []

def visit_ForCmd(self, cmd: ForCmd) -> list[Cmd]:
if not cmd.start.filled:
if not cmd.start.terminated:
starts = self.visit_Aexp(cmd.start)
ret: list[Cmd] = []
for start in starts:
Expand All @@ -112,7 +112,7 @@ def visit_ForCmd(self, cmd: ForCmd) -> list[Cmd]:
)
)
return ret
if not cmd.end.filled:
if not cmd.end.terminated:
ends = self.visit_Aexp(cmd.end)
ret: list[Cmd] = []
for end in ends:
Expand All @@ -125,7 +125,7 @@ def visit_ForCmd(self, cmd: ForCmd) -> list[Cmd]:
)
)
return ret
if not cmd.body.filled:
if not cmd.body.terminated:
self.for_depth += 1
bodies = self.visit_Cmd(cmd.body)
self.for_depth -= 1
Expand Down
17 changes: 17 additions & 0 deletions tests/test_transition.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@
ALL_AEXPS,
ALL_GATES,
CX,
Add,
Aexp,
Cmd,
ForCmd,
Gate,
GateCmd,
H,
HoleAexp,
Integer,
Pgm,
SeqCmd,
Expand Down Expand Up @@ -78,3 +80,18 @@ def test_hole_aexp3():
if type(pgm.body.gate.qreg2) not in [Integer, Var]:
aexp_types.remove(type(pgm.body.gate.qreg2))
assert len(aexp_types) == 3


def test_next_aexp():
pgm = Pgm("n", GateCmd(CX(Integer(0), Add())))
aexp_types: list[type[Aexp]] = ALL_AEXPS.copy()
pgms = next(pgm)
for pgm in pgms:
assert isinstance(pgm.body, GateCmd)
assert isinstance(pgm.body.gate, CX)
assert isinstance(pgm.body.gate.qreg2, Add)
assert type(pgm.body.gate.qreg2.a) in aexp_types
if type(pgm.body.gate.qreg2.a) not in [Integer, Var]:
aexp_types.remove(type(pgm.body.gate.qreg2.a))
assert isinstance(pgm.body.gate.qreg2.b, HoleAexp)
assert len(aexp_types) == 3

0 comments on commit 7587179

Please sign in to comment.