# --- Main runner (now computes LLM summaries) -------------------------------
def run_extraction_on_parquet(
parquet_path: pathlib.Path,
outdir: pathlib.Path,
*,
    only_types: Optional[List[str]] = None,
    per_type_limits: Optional[Dict[str, int]] = None,
chunk_type_label: Optional[str] = None,
model_id: Optional[str] = None,
temperature: float = 0.0,
provider_choice: Optional[str] = None,
override_lang_always: bool = True,
) -> Dict[str, Any]:
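    """Run grounded LangExtract extraction over a parquet of chunks.

    Writes grounded_extractions.jsonl, static/interactive HTML visualizations,
    and chunks_enriched.parquet under `outdir`, then returns a stats dict.
    Chunk summaries are LLM-generated, deduplicated by text hash, with a
    first-N-words fallback.

    Example (paths and doc types are illustrative):

        stats = run_extraction_on_parquet(
            pathlib.Path("chunks.parquet"),
            pathlib.Path("out"),
            only_types=["patent", "news"],
            per_type_limits={"patent": 50, "news": 50},
        )
    """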
outdir.mkdir(parents=True, exist_ok=True)
pyd_dir = outdir / "pydantic_json"
if int(os.getenv("LEX_WRITE_PYD_FILES", str(LEX_WRITE_PYD_FILES))):
pyd_dir.mkdir(parents=True, exist_ok=True)
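    # Only id/text/metadata are read; the input parquet must provide these columns.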
df = pd.read_parquet(parquet_path, columns=["id","text","metadata"])
# Filter by doc types if requested
def _row_doc_type(row: pd.Series) -> str:
meta = _safe_meta(row.get("metadata"))
cid = str(row.get("id"))
dt = _derive_doc_type(cid, meta) or "unknown"
return dt
def filter_df_by_doc_types(df: pd.DataFrame, allowed_types: List[str]) -> pd.DataFrame:
keep = []
for i, row in df.iterrows():
            dt = _row_doc_type(row)
            if dt in allowed_types:
                keep.append(i)
return df.loc[keep].copy()
if only_types:
df = filter_df_by_doc_types(df, [t.lower() for t in only_types])
    # Per-type limits (optional)
counts = {}
def select_rows_per_type(df: pd.DataFrame, per_type_limits: Dict[str, int]) -> pd.DataFrame:
counters = {k: 0 for k in per_type_limits.keys()}
keep_idx = []
for i, row in df.iterrows():
dt = _row_doc_type(row)
if dt in per_type_limits and counters[dt] < per_type_limits[dt]:
keep_idx.append(i)
counters[dt] += 1
if all(counters[k] >= per_type_limits[k] for k in per_type_limits):
break
out = df.loc[keep_idx].copy()
out.attrs["counts_by_type"] = counters
return out
if per_type_limits:
df = select_rows_per_type(df, per_type_limits)
counts = df.attrs.get("counts_by_type", {})
# Drop blanks
df["text"] = df["text"].astype(str)
df = df[df["text"].str.strip().ne("")]
    # Resolve the provider once so both the early return and the main path report it
    provider = provider_choice or ("gemini" if GEMINI_KEY else "openai")
    if df.empty:
        enriched_df = pd.DataFrame(columns=["id", "text", "metadata"])
        out_parquet = outdir / "chunks_enriched.parquet"
        try:
            enriched_df.to_parquet(out_parquet, index=False)
        except Exception as e:
            print("Parquet save failed:", e)
        return {
            "rows_processed": 0,
            "total_extractions": 0,
            "jsonl_path": str(outdir / "grounded_extractions.jsonl"),
            "vis_path": str(outdir / "visualization.html"),
            "pyd_dir": str(pyd_dir),
            "provider_used": provider,
            "model_used": (model_id or (MODEL_ID_GEMINI if provider == "gemini" else MODEL_ID_OPENAI)),
            "counts_by_type": counts,
            "enriched_parquet": str(out_parquet),
        }
# Records & unique texts (for dedup)
records = []
for row in df.itertuples(index=False):
cid = str(row.id)
text = str(row.text or "")
        # metadata may arrive as a dict or as a JSON-encoded string depending on the parquet writer
        if isinstance(row.metadata, dict):
            meta = row.metadata
        elif isinstance(row.metadata, str):
            meta = json.loads(row.metadata)
        else:
            meta = {}
thash = _hash_text(text)
records.append((cid, text, meta, thash))
unique_by_hash: Dict[str, tuple[str, dict]] = {}
for cid, text, meta, thash in records:
if thash not in unique_by_hash:
unique_by_hash[thash] = (text, meta)
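    # First occurrence wins: duplicate texts share a single cached extraction and summary.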
# Precompute summaries (parallel & cached) before extraction stream
_precompute_summaries_for_uniques(unique_by_hash, provider)
limiter = _RateLimiter(LEX_RPS)
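    # One shared limiter so all worker threads respect the LEX_RPS budget together.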
def _extract_one(text: str, meta: dict) -> Any:
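        # Serve from the process-wide cache when possible; otherwise rate-limit and extract.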
th = _hash_text(text)
with _EXTRACT_CACHE_LOCK:
cached = _EXTRACT_CACHE.get(th)
if cached is not None:
return cached
if override_lang_always:
_ = _get_lang_fast(text, meta, provider) # compute once; field merge later
limiter.wait()
res = lx.extract(
text_or_documents=text,
prompt_description=PROMPT,
examples=EXAMPLES,
model_id=(model_id or (MODEL_ID_GEMINI if provider=='gemini' else MODEL_ID_OPENAI)),
language_model_params={"temperature": float(temperature)},
)
with _EXTRACT_CACHE_LOCK:
_EXTRACT_CACHE[th] = res
return res
# Parallel extract uniques
to_run = []
with _EXTRACT_CACHE_LOCK:
for thash, (text, meta) in unique_by_hash.items():
if thash not in _EXTRACT_CACHE:
to_run.append((thash, text, meta))
if to_run:
with use_langextract_provider(provider, gemini_key=GEMINI_KEY, openai_key=OPENAI_KEY):
with ThreadPoolExecutor(max_workers=LEX_MAX_WORKERS) as ex:
futs = {ex.submit(_extract_one, text, meta): (thash, text) for thash, text, meta in to_run}
                for fut in as_completed(futs):
                    fut.result()  # propagate any worker exception
# Stream JSONL and build enriched rows in original order
jsonl_path = outdir / "grounded_extractions.jsonl"
enriched_rows: List[Dict[str, Any]] = []
with use_langextract_provider(provider, gemini_key=GEMINI_KEY, openai_key=OPENAI_KEY):
with open(jsonl_path, "w", encoding="utf-8") as jf:
for cid, text, meta, thash in records:
with _EXTRACT_CACHE_LOCK:
result = _EXTRACT_CACHE.get(thash)
if result is None:
result = _extract_one(text, meta)
lang_source = lang_llm = lang_final = None
if override_lang_always:
lang_source, lang_llm, lang_final = _get_lang_fast(text, meta, provider)
# --- NEW: LLM summary (dedup + fallback)
summary_text, summary_src = _get_summary_for_text(text, lang_final, provider)
summary_provider_eff = (LEX_SUMMARY_PROVIDER or provider)
summary_model_eff = (MODEL_ID_SUM_GEMINI if summary_provider_eff == "gemini" else MODEL_ID_SUM_OPENAI)
payload = map_extractions_to_schema(
chunk_id=cid,
text=text,
doc_meta=meta,
extractions=result.extractions,
chunk_type_label=chunk_type_label,
lang_llm=lang_llm,
chunk_summary_override=summary_text, # inject LLM summary
)
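                # Fold payload fields back into the chunk metadata, overwriting only when present.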
m = dict(meta)
m["doc_id"] = m.get("doc_id") or payload.doc_id
if payload.doc_type: m["doc_type"] = payload.doc_type
if payload.chunk_type: m["chunk_type"] = payload.chunk_type
if payload.chunk_index is not None: m["chunk_index"] = payload.chunk_index
if payload.token_count is not None: m["token_count"] = payload.token_count
if payload.title: m["title"] = payload.title
if payload.date: m["date"] = payload.date
if payload.source: m["source"] = payload.source
if payload.url is not None: m["url"] = payload.url
if lang_source is not None: m["lang_source"] = lang_source
if payload.lang_llm is not None: m["lang_llm"] = payload.lang_llm
if lang_final is not None: m["lang"] = lang_final
if payload.cpc_codes: m["cpc_codes"] = payload.cpc_codes
if payload.country_code: m["country_code"] = payload.country_code
if payload.content_year is not None: m["content_year"] = payload.content_year
if payload.content_month is not None: m["content_month"] = payload.content_month
m["extraction_count"] = payload.extraction_count
m["span_integrity_pct"] = payload.span_integrity_pct
# --- NEW: write summary + provenance
m["chunk_summary"] = payload.chunk_summary
m["chunk_summary_source"] = summary_src # "llm" or "first_n_words"
m["chunk_summary_provider"] = summary_provider_eff # "gemini" | "openai"
m["chunk_summary_model"] = summary_model_eff
# Entities & others
m["entities"] = payload.entities.model_dump()
m["topic_tags"] = payload.topic_tags
m["event_dates"] = [e.model_dump() for e in payload.event_dates]
m["role_annotations"] = [r.model_dump() for r in payload.role_annotations]
m["numeric_facts"] = [nf.model_dump() for nf in payload.numeric_facts]
enriched_rows.append({"id": cid, "text": text, "metadata": json.dumps(m, ensure_ascii=False)})
# JSON-safe extractions
ex_json = [_jsonable_extraction(ex) for ex in result.extractions]
jf.write(json.dumps({"document_id": cid, "text": text, "extractions": ex_json}, ensure_ascii=False) + "\n")
vis_path = outdir / "visualization.html"
if not LEX_SKIP_STATIC_VIZ:
try:
html_lib = lx.visualize(str(jsonl_path))
with open(vis_path, "w", encoding="utf-8") as f:
f.write(html_lib.data if hasattr(html_lib, "data") else html_lib)
except Exception as e:
print("Static visualization failed (skipped):", e)
# Our interactive viewer
make_interactive_html(jsonl_path, outdir / "visualization_interactive.html")
enriched_df = pd.DataFrame(enriched_rows, columns=["id","text","metadata"])
out_parquet = outdir / "chunks_enriched.parquet"
try:
enriched_df.to_parquet(out_parquet, index=False)
except Exception as e:
print("Parquet save failed:", e)
total_extractions = recount_extractions_from_file(jsonl_path)
return {
"rows_processed": int(len(enriched_rows)),
"total_extractions": int(total_extractions),
"jsonl_path": str(jsonl_path),
"vis_path": str(vis_path),
"pyd_dir": str(pyd_dir),
"provider_used": provider,
"model_used": (model_id or (MODEL_ID_GEMINI if provider=='gemini' else MODEL_ID_OPENAI)),
"counts_by_type": counts,
"enriched_parquet": str(out_parquet),
}
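
# --- Usage sketch (hypothetical CLI wiring; argument names and defaults are
# illustrative, not part of the module's contract) ----------------------------
if __name__ == "__main__":
    import argparse

    ap = argparse.ArgumentParser(description="Run grounded extraction over a chunks parquet.")
    ap.add_argument("parquet", type=pathlib.Path, help="input parquet with id/text/metadata")
    ap.add_argument("outdir", type=pathlib.Path, help="output directory")
    ap.add_argument("--provider", choices=["gemini", "openai"], default=None)
    ap.add_argument("--only-types", nargs="*", default=None, help="e.g. patent news")
    args = ap.parse_args()

    stats = run_extraction_on_parquet(
        args.parquet,
        args.outdir,
        only_types=args.only_types,
        provider_choice=args.provider,
    )
    print(json.dumps(stats, indent=2))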