gw/build_factory_data.py

74 lines
2.3 KiB
Python

import json
import os
import pandas as pd
from pathlib import Path
def load_factory_mapping(year: int) -> dict:
    """Read ``data/<year>/factory_mapping.json`` (UTF-8) and return the parsed JSON.

    In this script the result is used as a lookup from the factory names found
    in ``benchmark.csv`` to factory codes.

    Raises:
        FileNotFoundError: if the JSON file does not exist.
        json.JSONDecodeError: if the file is not valid JSON.
    """
    mapping_file = os.path.join("data", str(year), "factory_mapping.json")
    with open(mapping_file, mode="r", encoding="utf-8") as handle:
        raw_text = handle.read()
    return json.loads(raw_text)
def load_model_params(year: int) -> dict:
    """Return the parsed contents of ``data/<year>/model_params.json`` (UTF-8).

    The factory builder in this script reads ``factor_default`` and
    ``factor_<code>`` entries from the returned dict.

    Raises:
        FileNotFoundError: if the JSON file does not exist.
        json.JSONDecodeError: if the file is not valid JSON.
    """
    params_path = os.path.join("data", str(year), "model_params.json")
    with open(params_path, encoding="utf-8") as stream:
        params = json.load(stream)
    return params
def load_benchmark_factories(year: int) -> list[str]:
    """Return the distinct factory names listed in ``data/<year>/benchmark.csv``.

    The first column of the CSV is taken as the factory name; these names are
    later looked up in ``factory_mapping.json``. The file is decoded as UTF-8
    (with or without a BOM), falling back to GBK. Names are read as text,
    whitespace-stripped, and blank/missing cells are skipped. The order of
    first appearance is preserved.

    Raises:
        FileNotFoundError: if the CSV file does not exist.
        UnicodeDecodeError: if none of the candidate encodings can decode it.
    """
    path = Path("data") / str(year) / "benchmark.csv"
    # "utf-8-sig" decodes plain UTF-8 as well and strips a leading BOM, so a
    # separate "utf-8" attempt would be redundant.
    encodings = ("utf-8-sig", "gbk")
    last_error = None
    for enc in encodings:
        try:
            # dtype=str keeps names verbatim (e.g. "001" is not coerced to 1).
            df = pd.read_csv(path, encoding=enc, dtype=str)
        except UnicodeDecodeError as exc:
            last_error = exc
        else:
            break
    else:
        # Every encoding failed to decode; surface the last decode error.
        raise last_error

    # Assume the first column holds the factory name (aligns with factory_mapping keys).
    names = df.iloc[:, 0].dropna().str.strip()
    # Drop cells that were empty or whitespace-only; they are not factories.
    names = names[names != ""]
    return names.unique().tolist()
def build_factory_dataframe(year: int) -> pd.DataFrame:
    """Build one row per factory listed in the year's ``benchmark.csv``.

    Each benchmark factory name is translated to its code via
    ``factory_mapping.json``. Its factor is taken from ``model_params.json``
    under ``factor_<code>``. When that entry is absent or null, the factor
    falls back to ``factor_default``, and to 1 when that is also absent or null.

    Returns:
        A DataFrame with columns 工厂中文名 (name), 工厂英文名 (code),
        工厂平均磨合系数 (factor, as float) and 最小误差 (initialised to
        1_000_000). The columns are present even when there are no rows, so
        the written CSV always has a header.

    Raises:
        KeyError: if a benchmark factory has no entry in factory_mapping.json.
    """
    columns = ["工厂中文名", "工厂英文名", "工厂平均磨合系数", "最小误差"]
    initial_min_error = 1_000_000

    mapping = load_factory_mapping(year)
    params = load_model_params(year)
    # A missing or null default both mean "use 1" (loop-invariant, so hoisted).
    default_factor = params.get("factor_default")
    if default_factor is None:
        default_factor = 1
    benchmark_factories = load_benchmark_factories(year)

    rows = []
    for cn_name in benchmark_factories:
        code = mapping.get(cn_name)
        if code is None:
            raise KeyError(f"benchmark.csv中的工厂“{cn_name}”未在 factory_mapping.json 中找到映射。")
        factor = params.get(f"factor_{code}")
        if factor is None:
            # Treat an explicit null like a missing entry instead of crashing in float(None).
            factor = default_factor
        rows.append(
            {
                "工厂中文名": cn_name,
                "工厂英文名": code,
                "工厂平均磨合系数": float(factor),
                "最小误差": initial_min_error,
            }
        )
    # Explicit columns keep the schema (and CSV header) even for an empty list.
    return pd.DataFrame(rows, columns=columns)
def write_factory_csv(df: pd.DataFrame, year: int) -> str:
    """Save *df* to ``data/<year>/factory_data.csv`` and return that path.

    The CSV is written without the index column, encoded as UTF-8 with a BOM
    ("utf-8-sig"). The target directory must already exist.
    """
    target = os.path.join("data", str(year), "factory_data.csv")
    df.to_csv(target, encoding="utf-8-sig", index=False)
    return target
def main(year: int = 2025) -> None:
    """Generate ``data/<year>/factory_data.csv`` and print where it was written.

    Args:
        year: Data year to process. Defaults to 2025, the previously hard-coded
            value, so a bare ``main()`` behaves as before.
    """
    df = build_factory_dataframe(year)
    out_path = write_factory_csv(df, year)
    print(f"factory_data.csv generated at: {out_path}")
if __name__ == "__main__":
    # Only build the CSV when run as a script, not when imported as a module.
    main()