-
Notifications
You must be signed in to change notification settings - Fork 1.2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
c85d72b
commit 95c06c8
Showing
12 changed files
with
260 additions
and
239 deletions.
There are no files selected for viewing
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
import io | ||
|
||
import fitz as pymupdf | ||
import tempfile | ||
from bs4 import BeautifulSoup | ||
|
||
from marker.converters.pdf import PdfConverter | ||
|
||
def open_pymupdf(pdf_bytes):
    """Open an in-memory PDF with PyMuPDF.

    The raw ``pdf_bytes`` are wrapped in an ``io.BytesIO`` buffer, which is
    passed to ``pymupdf.open`` as its ``stream`` argument. Whatever that call
    returns is handed back to the caller.
    """
    return pymupdf.open(stream=io.BytesIO(pdf_bytes))
|
||
def clip_pdf_to_bbox(doc, bbox, padding=1):
    """Blank out everything outside ``bbox`` on the first page, then crop to it.

    Four redaction rectangles (left, top, right, bottom margins) are added
    around ``bbox``, each kept ``padding`` units away from the box itself.
    The redactions are applied, the page crop box is set to ``bbox``, and the
    same ``doc`` object is returned (it is modified in place).

    Args:
        doc: An opened PyMuPDF document; only page 0 is touched.
        bbox: Indexable of four numbers used as ``(x0, y0, x1, y1)``.
        padding: Gap left between ``bbox`` and the redacted margins.

    Returns:
        The ``doc`` that was passed in.
    """
    page = doc[0]
    bounds = page.bound()
    height, width = bounds.height, bounds.width

    margins = (
        (0, 0, bbox[0] - padding, height),      # strip left of the box
        (0, 0, width, bbox[1] - padding),       # strip above the box
        (bbox[2] + padding, 0, width, height),  # strip right of the box
        (0, bbox[3] + padding, width, height),  # strip below the box
    )
    for margin in margins:
        page.add_redact_annot(pymupdf.Rect(*margin))
    page.apply_redactions()

    page.set_cropbox(pymupdf.Rect(*bbox))
    return doc
|
||
def get_marker_block_html(marker_models: dict, gt_blocks: list, pdf_bytes: bytes):
    """Render each ground-truth block region of a PDF to HTML with marker.

    For every block, the PDF is reopened from ``pdf_bytes``, clipped to the
    block's bounding box, written to a temporary ``.pdf`` file and converted
    with a ``PdfConverter`` that is forced to treat the page as the block's
    type. The inner HTML of the rendered ``<body>`` is collected.

    Args:
        marker_models: Passed to ``PdfConverter`` as ``artifact_dict``.
        gt_blocks: Ground-truth blocks; each must provide ``"bbox"`` and
            ``"block_type"``.
        pdf_bytes: Raw bytes of the PDF to convert.

    Returns:
        A list with one HTML string per entry of ``gt_blocks``, in order.
    """
    block_html = []
    for block in gt_blocks:
        block_converter = PdfConverter(
            artifact_dict=marker_models,
            config={"page_range": [0], "force_layout_block": block["block_type"], "disable_tqdm": True},
            renderer="marker.renderers.html.HTMLRenderer"
        )
        with tempfile.NamedTemporaryFile(suffix=".pdf", mode="wb") as f:
            doc2 = open_pymupdf(pdf_bytes)
            try:
                clip_pdf_to_bbox(doc2, block["bbox"])
                doc2.save(f)
            finally:
                # Release the document even if clipping or saving fails.
                doc2.close()
            # The converter reopens the file by name, so buffered bytes must
            # reach the disk before it runs.
            f.flush()
            rendered = block_converter(f.name)

        soup = BeautifulSoup(rendered.html, "html.parser")
        body = soup.find("body")
        # Fall back to the whole fragment if the renderer emitted no <body>.
        container = body if body is not None else soup
        block_html.append(str(container.decode_contents()))
    return block_html
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
import json | ||
import os | ||
from collections import defaultdict | ||
from pathlib import Path | ||
|
||
import click | ||
import datasets | ||
import tabulate | ||
from tqdm import tqdm | ||
|
||
from marker.logger import configure_logging | ||
from marker.models import create_model_dict | ||
from inference import get_marker_block_html | ||
from marker.settings import settings | ||
from scoring import score_blocks | ||
|
||
configure_logging() | ||
|
||
@click.command(help="Benchmark PDF to MD conversion.")
@click.option("--dataset", type=str, help="Path to the benchmark dataset", default="datalab-to/marker_benchmark")
@click.option("--other_methods", type=str, help="Comma separated list of other methods to compare against. Possible values:", default="")
@click.option("--result_path", type=str, default=os.path.join(settings.OUTPUT_DIR, "benchmark", "overall"), help="Output path for results.")
@click.option("--max_rows", type=int, default=None, help="Maximum number of rows to process.")
def main(
    dataset: str,
    other_methods: str,
    result_path: str,
    max_rows: int
):
    """Score marker's per-block HTML against a ground-truth dataset.

    Each dataset sample supplies ``gt_blocks`` (a JSON string), a
    ``classification`` and ``pdf`` bytes. Every block is converted and scored,
    and a per-sample score weighted by ground-truth HTML length is stored.
    Averages by document type and by block type, plus the overall average,
    are printed, and the per-sample results are written to
    ``<result_path>/overall.json``.

    Args:
        dataset: Dataset identifier handed to ``datasets.load_dataset``.
        other_methods: Comma separated method names; validated against the
            allowed list.
        result_path: Directory that receives ``overall.json``.
        max_rows: Upper bound on the number of samples to process, or
            ``None`` for no limit.

    Raises:
        ValueError: If ``other_methods`` names a method that is not allowed.
    """
    allowed_methods = [""]
    methods = other_methods.split(",")
    for method in methods:
        if method not in allowed_methods:
            raise ValueError(f"Method {method} not allowed. Allowed methods are {allowed_methods}")

    model_dict = create_model_dict()
    ds = datasets.load_dataset(dataset, split="train")

    bench_scores = {}
    averages_by_type = defaultdict(list)
    averages_by_block_type = defaultdict(list)
    for idx, sample in tqdm(enumerate(ds), desc="Running benchmark"):
        # Check the limit before doing any work so exactly max_rows samples run.
        if max_rows is not None and idx >= max_rows:
            break

        gt_blocks = json.loads(sample["gt_blocks"])
        doc_type = sample["classification"]
        pdf_bytes = sample["pdf"]  # This is a single page PDF
        marker_html = get_marker_block_html(model_dict, gt_blocks, pdf_bytes)
        gt_html = [block["html"] for block in gt_blocks]
        scores = score_blocks(gt_html, marker_html)
        gt_weights = [len(ht) for ht in gt_html]

        total_weight = sum(gt_weights)
        if total_weight == 0:
            # No ground-truth text to weight by; a weighted score is undefined.
            continue

        # Weighted score, weighted by length of GT block
        overall_score = sum(s * w for s, w in zip(scores, gt_weights)) / total_weight
        bench_scores[idx] = {
            "scores": scores,
            "weights": gt_weights,
            "overall_score": overall_score
        }

        averages_by_type[doc_type].append(overall_score)

        for score, gt_block in zip(scores, gt_blocks):
            averages_by_block_type[gt_block["block_type"]].append(score)

    # Every list below has at least one entry: keys only exist after an append.
    type_rows = sorted((k, sum(v) / len(v)) for k, v in averages_by_type.items())
    print(tabulate.tabulate(type_rows, headers=["Document Type", "Average Score"], tablefmt="github"))

    block_type_rows = sorted((k, sum(v) / len(v)) for k, v in averages_by_block_type.items())
    print(tabulate.tabulate(block_type_rows, headers=["Block Type", "Average Score"], tablefmt="github"))

    if bench_scores:
        overall_average = sum(entry["overall_score"] for entry in bench_scores.values()) / len(bench_scores)
    else:
        # Nothing was scored (empty dataset, max_rows=0, or only empty samples).
        overall_average = float("nan")
    print(tabulate.tabulate([["Overall Average", overall_average]], tablefmt="github"))

    out_dir = Path(result_path)
    # The default output directory may not exist on a fresh checkout.
    out_dir.mkdir(parents=True, exist_ok=True)
    out_path = out_dir / "overall.json"
    with open(out_path, "w") as f:
        json.dump(bench_scores, f, indent=2)

    print(f"Results saved to {out_path}.")
|
||
# Run the benchmark CLI only when this file is executed as a script.
if __name__ == "__main__":
    main()
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
import re | ||
from bs4 import BeautifulSoup | ||
|
||
from markdownify import markdownify as md | ||
from rapidfuzz import fuzz | ||
|
||
def standardize_html(html):
    """Normalize an HTML fragment to a single-line markdown string for scoring.

    All heading tags are rewritten to ``h1``, the HTML is converted to
    markdown, literal ``<br>`` markers and every run of whitespace (newlines
    included) become a single space, runs of periods become one period, and
    the result is stripped.

    Args:
        html: HTML source to normalize.

    Returns:
        The normalized markdown text.
    """
    soup = BeautifulSoup(html, "html.parser")

    # Convert all headers to h1 so we don't penalize small differences in header levels
    for tag in soup.find_all(["h1", "h2", "h3", "h4", "h5", "h6"]):
        tag.name = "h1"

    markdown = md(str(soup))
    markdown = markdown.replace("<br>", "\n")
    # Collapses newlines too, so no separate newline-squeezing pass is needed.
    markdown = re.sub(r"\s+", " ", markdown)
    # Replace repeated periods with a single period, like in table of contents
    markdown = re.sub(r"\.+", ".", markdown)
    return markdown.strip()
|
||
|
||
def score_blocks(gt_html, method_html):
    """Score each produced HTML block against its ground-truth counterpart.

    The two sequences are paired positionally with ``zip``. Both members of a
    pair are passed through ``standardize_html`` and then compared with
    ``fuzz.ratio``.

    Args:
        gt_html: Ground-truth HTML strings.
        method_html: HTML strings produced by the method under test.

    Returns:
        A list of ``fuzz.ratio`` results, one per pair.
    """
    return [
        fuzz.ratio(standardize_html(expected), standardize_html(produced))
        for expected, produced in zip(gt_html, method_html)
    ]
This file was deleted.
Oops, something went wrong.
Oops, something went wrong.