Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -539,13 +539,14 @@ def build_graph(
ocr_kwargs: dict[str, Any] = {}
if extract_params.method in ("pdfium_hybrid", "ocr") and extract_params.extract_text:
ocr_kwargs["extract_text"] = True
if extract_params.extract_tables and not extract_params.use_table_structure:
if extract_params.extract_tables:
ocr_kwargs["extract_tables"] = True
if extract_params.extract_charts and not extract_params.use_graphic_elements:
ocr_kwargs["extract_charts"] = True
if extract_params.extract_infographics:
ocr_kwargs["extract_infographics"] = True
ocr_kwargs["use_graphic_elements"] = extract_params.use_graphic_elements
ocr_kwargs["use_table_structure"] = extract_params.use_table_structure
if extract_params.ocr_invoke_url:
ocr_kwargs["ocr_invoke_url"] = extract_params.ocr_invoke_url
if extract_params.api_key:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,10 @@ def _parse_mode_enabled(extract_params: ExtractParams) -> bool:
def _ocr_stage_needed(extract_params: ExtractParams) -> bool:
if extract_params.method in ("pdfium_hybrid", "ocr") and extract_params.extract_text:
return True
if extract_params.extract_tables and not extract_params.use_table_structure:
if extract_params.extract_tables:
# OCR is always needed for table crops: either to produce pseudo-markdown
# (when use_table_structure=False) or to join against the
# table_structure_v1 detections published by TableStructureActor.
return True
if extract_params.extract_charts and not extract_params.use_graphic_elements:
return True
Expand Down Expand Up @@ -304,10 +307,13 @@ def _run_detection_pipeline(self, batch_df: pd.DataFrame) -> pd.DataFrame:
graphic_kwargs["api_key"] = extract_params.api_key
batch_df = self._instantiate_resolved(GraphicElementsActor, **graphic_kwargs).run(batch_df)

ocr_kwargs: dict[str, Any] = {"use_graphic_elements": extract_params.use_graphic_elements}
ocr_kwargs: dict[str, Any] = {
"use_graphic_elements": extract_params.use_graphic_elements,
"use_table_structure": extract_params.use_table_structure,
}
if extract_params.method in ("pdfium_hybrid", "ocr") and extract_params.extract_text:
ocr_kwargs["extract_text"] = True
if extract_params.extract_tables and not extract_params.use_table_structure:
if extract_params.extract_tables:
ocr_kwargs["extract_tables"] = True
if extract_params.extract_charts and not extract_params.use_graphic_elements:
ocr_kwargs["extract_charts"] = True
Expand Down
1 change: 1 addition & 0 deletions nemo_retriever/src/nemo_retriever/ocr/cpu_ocr.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def __init__(self, **ocr_kwargs: Any) -> None:
self.ocr_kwargs["extract_charts"] = bool(self.ocr_kwargs.get("extract_charts", False))
self.ocr_kwargs["extract_infographics"] = bool(self.ocr_kwargs.get("extract_infographics", False))
self.ocr_kwargs["use_graphic_elements"] = bool(self.ocr_kwargs.get("use_graphic_elements", False))
self.ocr_kwargs["use_table_structure"] = bool(self.ocr_kwargs.get("use_table_structure", False))
self.ocr_kwargs["request_timeout_s"] = float(self.ocr_kwargs.get("request_timeout_s", 120.0))
self.ocr_kwargs["inference_batch_size"] = int(self.ocr_kwargs.get("inference_batch_size", 8))

Expand Down
1 change: 1 addition & 0 deletions nemo_retriever/src/nemo_retriever/ocr/gpu_ocr.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def __init__(self, **ocr_kwargs: Any) -> None:
self.ocr_kwargs["extract_charts"] = bool(self.ocr_kwargs.get("extract_charts", False))
self.ocr_kwargs["extract_infographics"] = bool(self.ocr_kwargs.get("extract_infographics", False))
self.ocr_kwargs["use_graphic_elements"] = bool(self.ocr_kwargs.get("use_graphic_elements", False))
self.ocr_kwargs["use_table_structure"] = bool(self.ocr_kwargs.get("use_table_structure", False))
self.ocr_kwargs["request_timeout_s"] = float(self.ocr_kwargs.get("request_timeout_s", 120.0))
self.ocr_kwargs["inference_batch_size"] = int(self.ocr_kwargs.get("inference_batch_size", 8))

Expand Down
66 changes: 63 additions & 3 deletions nemo_retriever/src/nemo_retriever/ocr/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@
import pandas as pd
from nemo_retriever.params import RemoteRetryParams
from nemo_retriever.nim.nim import invoke_image_inference_batches
from nemo_retriever.utils.table_and_chart import join_graphic_elements_and_ocr_output
from nemo_retriever.utils.table_and_chart import (
join_graphic_elements_and_ocr_output,
join_table_structure_and_ocr_output,
)

try:
from PIL import Image
Expand Down Expand Up @@ -441,6 +444,46 @@ def _find_ge_detections_for_bbox(
return None


def _find_ts_detections_for_bbox(
row: Any,
table_bbox: Sequence[float],
) -> Optional[Tuple[List[Dict[str, Any]], Optional[Tuple[int, int]]]]:
"""Find table-structure detections + crop size for a table bbox.

Reads the ``table_structure_v1`` column from *row* and returns the
``(detections, (H, W))`` tuple for the region whose ``bbox_xyxy_norm``
matches *table_bbox*. Returns ``None`` if the column is missing, no
region matches, or the matching region has no detections.
"""
ts_col = getattr(row, "table_structure_v1", None)
if not isinstance(ts_col, dict):
return None
regions = ts_col.get("regions")
if not isinstance(regions, list):
return None

for region in regions:
if not isinstance(region, dict):
continue
region_bbox = region.get("bbox_xyxy_norm")
if not isinstance(region_bbox, (list, tuple)) or len(region_bbox) != 4:
continue
if not _bboxes_close(table_bbox, region_bbox):
continue
dets = region.get("detections")
if not isinstance(dets, list) or not dets:
return None
hw = region.get("orig_shape_hw")
hw_t: Optional[Tuple[int, int]] = None
if isinstance(hw, (list, tuple)) and len(hw) == 2:
try:
hw_t = (int(hw[0]), int(hw[1]))
except (TypeError, ValueError):
hw_t = None
return (dets, hw_t)
return None


# ---------------------------------------------------------------------------
# Core function
# ---------------------------------------------------------------------------
Expand All @@ -458,6 +501,7 @@ def ocr_page_elements(
extract_charts: bool = False,
extract_infographics: bool = False,
use_graphic_elements: bool = False,
use_table_structure: bool = False,
inference_batch_size: int = 8,
remote_retry: RemoteRetryParams | None = None,
**kwargs: Any,
Expand Down Expand Up @@ -605,7 +649,16 @@ def ocr_page_elements(
crop_hw_table = (_ch, _cw)
except Exception:
pass
text = _blocks_to_pseudo_markdown(blocks, crop_hw=crop_hw_table) or _blocks_to_text(blocks)
text = ""
if use_table_structure:
ts_match = _find_ts_detections_for_bbox(row, bbox)
if ts_match is not None:
ts_dets, ts_hw = ts_match
text = join_table_structure_and_ocr_output(ts_dets, preds, ts_hw or crop_hw_table)
if not text:
text = _blocks_to_pseudo_markdown(blocks, crop_hw=crop_hw_table) or _blocks_to_text(
blocks
)
else:
text = _blocks_to_text(blocks)
entry = {"bbox_xyxy_norm": bbox, "text": text}
Expand Down Expand Up @@ -646,7 +699,14 @@ def _append_local_result(
return
blocks = _parse_ocr_result(preds)
if label_name == "table":
text = _blocks_to_pseudo_markdown(blocks, crop_hw=crop_hw)
text = ""
if use_table_structure:
ts_match = _find_ts_detections_for_bbox(row, bbox)
if ts_match is not None:
ts_dets, ts_hw = ts_match
text = join_table_structure_and_ocr_output(ts_dets, preds, ts_hw or crop_hw)
if not text:
text = _blocks_to_pseudo_markdown(blocks, crop_hw=crop_hw)
if not text:
text = _blocks_to_text(blocks)
else:
Expand Down
2 changes: 0 additions & 2 deletions nemo_retriever/src/nemo_retriever/table/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ class TableExtractionStageConfig:
@dataclass(frozen=True)
class TableStructureOCRStageConfig:
table_structure_invoke_url: str = ""
ocr_invoke_url: str = ""
api_key: str = ""
request_timeout_s: float = 60.0

Expand All @@ -43,7 +42,6 @@ def load_table_structure_ocr_config_from_dict(cfg: Dict[str, Any]) -> TableStruc
cfg = dict(cfg or {})
return TableStructureOCRStageConfig(
table_structure_invoke_url=str(cfg.get("table_structure_invoke_url") or ""),
ocr_invoke_url=str(cfg.get("ocr_invoke_url") or ""),
api_key=str(cfg.get("api_key") or ""),
request_timeout_s=float(cfg.get("request_timeout_s", 60.0)),
)
1 change: 0 additions & 1 deletion nemo_retriever/src/nemo_retriever/table/cpu_actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ def __init__(
*,
table_structure_invoke_url: Optional[str] = None,
invoke_url: Optional[str] = None,
ocr_invoke_url: Optional[str] = None,
api_key: Optional[str] = None,
table_output_format: Optional[str] = None,
request_timeout_s: float = 120.0,
Expand Down
1 change: 0 additions & 1 deletion nemo_retriever/src/nemo_retriever/table/gpu_actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ def __init__(
*,
table_structure_invoke_url: Optional[str] = None,
invoke_url: Optional[str] = None,
ocr_invoke_url: Optional[str] = None,
api_key: Optional[str] = None,
table_output_format: Optional[str] = None,
request_timeout_s: float = 120.0,
Expand Down
31 changes: 28 additions & 3 deletions nemo_retriever/src/nemo_retriever/table/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,13 @@ def table_structure_ocr_page_elements(
Returns
-------
pandas.DataFrame
Original columns plus ``table`` and ``table_structure_ocr_v1``.
Original columns plus:

- ``table``: list of per-crop dicts with structure-only ``text`` fallback
(overwritten by the OCR stage when it runs with ``use_table_structure=True``).
- ``table_structure_v1``: page-level ``{regions, timing, error}`` payload
consumed by the OCR stage to join OCR text with structure detections.
- ``table_structure_ocr_v1``: per-row timing/error metadata.
"""
from nemo_retriever.nim.nim import invoke_image_inference_batches
from nemo_retriever.ocr.ocr import _crop_all_from_page, _np_rgb_to_b64_png
Expand Down Expand Up @@ -310,12 +316,14 @@ def table_structure_ocr_page_elements(

# Per-row accumulators.
all_table: List[List[Dict[str, Any]]] = []
all_ts_payloads: List[Dict[str, Any]] = []
all_meta: List[Dict[str, Any]] = []

t0_total = time.perf_counter()

for row in batch_df.itertuples(index=False):
table_items: List[Dict[str, Any]] = []
ts_regions: List[Dict[str, Any]] = []
row_error: Any = None

try:
Expand All @@ -333,6 +341,7 @@ def table_structure_ocr_page_elements(

if not isinstance(page_image_b64, str) or not page_image_b64:
all_table.append(table_items)
all_ts_payloads.append({"regions": ts_regions, "timing": None, "error": None})
all_meta.append({"timing": None, "error": None})
continue

Expand All @@ -341,6 +350,7 @@ def table_structure_ocr_page_elements(

if not crops:
all_table.append(table_items)
all_ts_payloads.append({"regions": ts_regions, "timing": None, "error": None})
all_meta.append({"timing": None, "error": None})
continue

Expand Down Expand Up @@ -386,8 +396,10 @@ def table_structure_ocr_page_elements(
structure_results.append([d for d in dets if (d.get("score") or 0.0) >= YOLOX_TABLE_MIN_SCORE])

# --- Pass 3: Build structure-only output per crop ---
for crop_i, (_, bbox, _) in enumerate(crops):
for crop_i, (_, bbox, crop_array) in enumerate(crops):
structure_dets = structure_results[crop_i]
crop_hw = (int(crop_array.shape[0]), int(crop_array.shape[1]))
counts = _count_structure_labels(structure_dets)
table_items.append(
{
"bbox_xyxy_norm": bbox,
Expand All @@ -396,7 +408,16 @@ def table_structure_ocr_page_elements(
table_output_format=table_output_format,
),
"structure_detections": structure_dets,
"structure_counts": _count_structure_labels(structure_dets),
"structure_counts": counts,
}
)
ts_regions.append(
{
"bbox_xyxy_norm": [float(x) for x in bbox],
"label_name": "table",
"detections": structure_dets,
"orig_shape_hw": [crop_hw[0], crop_hw[1]],
"structure_counts": counts,
}
)

Expand All @@ -410,14 +431,18 @@ def table_structure_ocr_page_elements(
}

all_table.append(table_items)
all_ts_payloads.append({"regions": ts_regions, "timing": None, "error": row_error})
all_meta.append({"timing": None, "error": row_error})

elapsed = time.perf_counter() - t0_total
for meta in all_meta:
meta["timing"] = {"seconds": float(elapsed)}
for payload in all_ts_payloads:
payload["timing"] = {"seconds": float(elapsed)}

out = batch_df.copy()
out["table"] = all_table
out["table_structure_v1"] = all_ts_payloads
out["table_structure_ocr_v1"] = all_meta
return out

Expand Down
38 changes: 0 additions & 38 deletions nemo_retriever/src/nemo_retriever/utils/table_and_chart.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,6 @@ def _join_yolox_table_structure_and_ocr_output(
df_table = df_assign[df_assign["is_table"]].reset_index(drop=True)
if len(df_table):
mat = build_markdown(df_table)
mat = _trim_non_table_edge_rows(mat)
markdown_table = display_markdown(mat, use_header=True)

all_boxes = np.stack(df_table.ocr_box.values)
Expand Down Expand Up @@ -363,43 +362,6 @@ def remove_empty_row(mat: list) -> list:
return mat_filter


def _trim_non_table_edge_rows(mat: list) -> list:
"""Remove leading/trailing rows that look like non-table content.

Heuristics applied only to edge rows:
- All non-empty cells contain identical text (duplicated caption).
- Less than half the cells are filled (stray text from surrounding content).
"""
if len(mat) <= 1:
return mat

n_cols = max(len(row) for row in mat) if mat else 0
if n_cols < 2:
return mat

def _is_noise_row(row: list) -> bool:
non_empty = [c for c in row if c.strip()]
if not non_empty:
return True
# All non-empty cells identical (repeated caption text).
if len(non_empty) > 1 and len(set(non_empty)) == 1:
return True
# Half or fewer cells filled.
if len(non_empty) <= n_cols / 2:
return True
return False

# Trim leading noise rows.
while len(mat) > 1 and _is_noise_row(mat[0]):
mat = mat[1:]

# Trim trailing noise rows.
while len(mat) > 1 and _is_noise_row(mat[-1]):
mat = mat[:-1]

return mat


def reorder_boxes(
boxes: np.ndarray,
texts: list,
Expand Down
Loading
Loading