Skip to content

Commit

Permalink
Support filtering for fairy piece positions
Browse files Browse the repository at this point in the history
  • Loading branch information
ianfab committed Jul 22, 2023
1 parent a8bf5fe commit 4938027
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import uci


def generate_fens(engine, variant, min_depth, max_depth, add_move):
def generate_fens(engine, variant, min_depth, max_depth, add_move, required_pieces):
if variant not in sf.variants():
raise Exception("Unsupported variant: {}".format(variant))

Expand All @@ -31,13 +31,13 @@ def generate_fens(engine, variant, min_depth, max_depth, add_move):
bestmove = None
else:
fen = sf.get_fen(variant, start_fen, move_stack[:-1])
if (fen, bestmove) not in fens:
if (fen, bestmove) not in fens and (not required_pieces or any(p in fen.split(' ')[0].lower() for p in required_pieces.lower())):
fens.add((fen, bestmove))
yield fen, bestmove


def write_fens(stream, engine, variant, count, min_depth, max_depth, add_move):
generator = generate_fens(engine, variant, min_depth, max_depth, add_move)
def write_fens(stream, engine, variant, count, min_depth, max_depth, add_move, required_pieces):
generator = generate_fens(engine, variant, min_depth, max_depth, add_move, required_pieces)
for _ in tqdm(range(count)):
fen, move = next(generator)
stream.write('{};variant {}'.format(fen, variant) + (';sm {}'.format(move) if move else '') + os.linesep)
Expand All @@ -54,11 +54,12 @@ def write_fens(stream, engine, variant, count, min_depth, max_depth, add_move):
parser.add_argument('-d', '--max-depth', type=int, default=6, help='maximum search depth')
parser.add_argument('-m', '--min-depth', type=int, default=1, help='minimum search depth')
parser.add_argument('-a', '--add-move', action='store_true', help='add initial move for opposing side')
parser.add_argument('-p', '--pieces', default=None, help='only return positions containing one of these piece chars (case insensitive)')
args = parser.parse_args()

ucioptions = dict(args.ucioptions)
ucioptions.update({'Skill Level': args.skill_level})

engine = uci.Engine([args.engine], ucioptions)
sf.set_option("VariantPath", engine.options.get("VariantPath", ""))
write_fens(sys.stdout, engine, args.variant, args.count, args.min_depth, args.max_depth, args.add_move)
write_fens(sys.stdout, engine, args.variant, args.count, args.min_depth, args.max_depth, args.add_move, args.pieces)

0 comments on commit 4938027

Please sign in to comment.