307 lines
12 KiB
Python
307 lines
12 KiB
Python
import argparse
|
||
import json
|
||
import random
|
||
from pathlib import Path
|
||
from typing import List, Tuple
|
||
|
||
import pandas as pd
|
||
import time
|
||
import os
|
||
import gc
|
||
|
||
from simulation_model import SimulationModel
|
||
|
||
# --- Genetic-algorithm hyper-parameters ---------------------------------------
POP_SIZE: int = 20  # individuals per generation
GENERATIONS: int = 50  # generations evolved per calibration run
MUTATION_RATE: float = 0.2  # per-gene probability of applying Gaussian jitter
MUTATION_STD: float = 0.05  # std-dev of that jitter (for factors)

# --- Coordination / search-control settings -----------------------------------
LOCK_TIMEOUT: int = 120  # seconds to keep retrying a CSV lock before giving up
STAGNATION_WINDOW: int = 5  # generations without improvement before injecting LHS samples
|
||
|
||
|
||
def clip(val: float, bounds: Tuple[float, float]) -> float:
    """Clamp *val* into the interval described by *bounds* = (low, high).

    The upper cap is applied first, then the lower floor, so if low > high
    the result is always ``low``.
    """
    low, high = bounds
    capped = min(high, val)
    return max(low, capped)
|
||
|
||
|
||
def _match_factory_row(df: pd.DataFrame, cols: dict, factory_code: str) -> pd.Series:
|
||
"""Match factory by English name first, fall back to Chinese name."""
|
||
eng_col = cols["工厂英文名"]
|
||
cn_col = cols["工厂中文名"]
|
||
code_norm = factory_code.strip().lower()
|
||
mask_eng = df[eng_col].astype(str).str.strip().str.lower() == code_norm
|
||
if mask_eng.any():
|
||
return df.loc[mask_eng].iloc[0]
|
||
mask_cn = df[cn_col].astype(str).str.strip() == factory_code.strip()
|
||
if mask_cn.any():
|
||
return df.loc[mask_cn].iloc[0]
|
||
raise ValueError(f"在{eng_col}/{cn_col} 中找不到工厂 {factory_code}")
|
||
|
||
|
||
def load_factory_row(csv_path: Path, factory_code: str) -> tuple[pd.Series, pd.DataFrame, dict]:
    """Read the factory CSV and locate the row for *factory_code*.

    Returns a tuple ``(row, dataframe, column_map)``; ``column_map`` maps the
    stripped canonical header names to the headers actually present in the file.
    Raises ValueError (from the CSV reader or the row matcher) when required
    columns or the factory are missing.
    """
    needed = {"工厂中文名", "工厂英文名", "工厂平均磨合系数", "最小误差"}
    table, col_map, _encoding = read_csv_with_encoding(csv_path, required=needed)
    matched = _match_factory_row(table, col_map, factory_code)
    return matched, table, col_map
|
||
|
||
|
||
def _csv_lock_path(csv_path: Path) -> Path:
|
||
return csv_path.with_suffix(csv_path.suffix + ".lock")
|
||
|
||
|
||
def _acquire_lock(lock_path: Path, timeout: float = LOCK_TIMEOUT, interval: float = 0.5):
    """Acquire an exclusive lock by atomically creating *lock_path*.

    ``O_CREAT | O_EXCL`` makes the create fail when the file already exists,
    so only one caller at a time can hold the lock.  The caller's PID is
    written into the lock file.

    Args:
        lock_path: Path of the lock file to create.
        timeout: Maximum number of seconds to keep retrying.
        interval: Seconds to sleep between attempts.

    Returns:
        The open file descriptor of the lock file; pass it to ``_release_lock``.

    Raises:
        TimeoutError: if the lock file still exists after *timeout* seconds.
    """
    # monotonic() is unaffected by wall-clock adjustments (NTP sync, manual
    # changes), which could otherwise shorten or extend the wait arbitrarily.
    start = time.monotonic()
    while True:
        try:
            fd = os.open(lock_path, os.O_CREAT | os.O_EXCL | os.O_RDWR)
        except FileExistsError:
            if time.monotonic() - start > timeout:
                raise TimeoutError(f"获取锁超时: {lock_path}") from None
            time.sleep(interval)
            continue

        try:
            os.write(fd, str(os.getpid()).encode())
        except BaseException:
            # The lock file exists but cannot be handed to the caller: close it
            # and remove it, otherwise every later run would block until timeout.
            os.close(fd)
            try:
                os.unlink(lock_path)
            except OSError:
                pass  # best effort; the original error re-raised below matters more
            raise
        return fd
|
||
|
||
|
||
def _release_lock(lock_path: Path, fd: int):
|
||
try:
|
||
os.close(fd)
|
||
finally:
|
||
if lock_path.exists():
|
||
try:
|
||
lock_path.unlink()
|
||
except Exception:
|
||
pass
|
||
|
||
|
||
def update_factory_csv(csv_path: Path, factory_code: str, new_factor: float, new_error: float) -> None:
    """Store a factory's new average factor and minimum error in the factory CSV.

    The read-modify-write cycle runs while holding the sidecar lock for
    *csv_path*; the file is rewritten with the encoding it was read with.
    The factory row is found the same way as ``load_factory_row`` (English
    name first, then Chinese name); ValueError is raised if it is missing.
    """
    lock_file = _csv_lock_path(csv_path)
    handle = _acquire_lock(lock_file)
    try:
        table, col_of, encoding = read_csv_with_encoding(
            csv_path, required={"工厂中文名", "工厂英文名", "工厂平均磨合系数", "最小误差"}
        )
        target = _match_factory_row(table, col_of, factory_code)
        hit = table.index == target.name
        new_values = {
            col_of["工厂平均磨合系数"]: float(new_factor),
            col_of["最小误差"]: float(new_error),
        }
        for column, value in new_values.items():
            table.loc[hit, column] = value
        table.to_csv(csv_path, index=False, encoding=encoding)
    finally:
        # Always drop the lock, even when reading/matching/writing failed.
        _release_lock(lock_file, handle)
|
||
|
||
def update_production_line_csv(csv_path: Path, factory_name_cn: str, line_ids: List[str], best_genes: List[float]) -> None:
    """Write calibrated per-line factors for one factory into ProductionLine.csv.

    The read-modify-write cycle runs while holding the sidecar lock for
    *csv_path*; the file is rewritten with the encoding it was read with.

    A row is updated when its stripped '工厂名' equals *factory_name_cn* and
    the stripped string form of its '产线ID' is one of *line_ids*.  Its
    '磨合系数' becomes the value paired positionally with that ID in
    *best_genes* (extra items in the longer list are ignored).  All other
    rows and columns are left as read.

    Raises:
        ValueError: if no row has '工厂名' equal to *factory_name_cn*.
    """
    lock_path = _csv_lock_path(csv_path)
    fd = _acquire_lock(lock_path)
    try:
        df, cols, enc = read_csv_with_encoding(csv_path, required={"工厂名", "产线ID", "磨合系数"})
        factory_mask = df[cols["工厂名"]].astype(str).str.strip() == factory_name_cn
        if not factory_mask.any():
            raise ValueError(f"在 {csv_path} 中找不到工厂名: {factory_name_cn}")

        line_to_factor = dict(zip(line_ids, best_genes))
        # Match on stripped string IDs (the same normalisation load_factory_lines
        # uses) but do NOT write those strings back into the ID column: that
        # would change its dtype and save missing IDs as the literal text "nan".
        id_keys = df[cols["产线ID"]].astype(str).str.strip()
        target = factory_mask & id_keys.isin(list(line_to_factor))
        if target.any():
            df.loc[target, cols["磨合系数"]] = id_keys[target].map(
                lambda lid: float(line_to_factor[lid])
            )

        df.to_csv(csv_path, index=False, encoding=enc)
    finally:
        _release_lock(lock_path, fd)
|
||
|
||
|
||
def evaluate(factory_code: str, factory_name_cn: str, line_ids: List[str], genes: List[float]) -> float:
    """Run one calibration-mode simulation with candidate line factors; return its error.

    A fresh ``SimulationModel`` is built (empty ``factory_factors``, output
    disabled, calibration mode on).  Each ``line_ids[k]`` gets factor
    ``float(genes[k])`` in ``model.line_factor``, and the model is stepped
    until ``running`` is false.

    The score is the ``error_ratio`` of the first row of
    ``model.factory_error_df`` whose stripped ``name`` equals
    *factory_name_cn*.  If that frame is None/empty, has no matching row, or
    lacks an ``error_ratio`` column, ``model.error`` is returned instead.
    *factory_code* is accepted for interface compatibility but not used here.
    """
    sim = SimulationModel(
        factory_factors={},
        output_enabled=False,
        is_calibration_mode=True,
    )

    # Override per-line factors with the candidate genes.
    for line_id, factor in zip(line_ids, genes):
        sim.line_factor[line_id] = float(factor)

    while sim.running:
        sim.step()

    per_factory = sim.factory_error_df
    if per_factory is None or per_factory.empty:
        return sim.error

    wanted = str(factory_name_cn).strip()
    hits = per_factory[per_factory["name"].astype(str).str.strip() == wanted]
    if hits.empty or "error_ratio" not in hits.columns:
        return sim.error
    return float(hits.iloc[0]["error_ratio"])
|
||
|
||
|
||
def mutate(genes: List[float]) -> List[float]:
    """Return a mutated copy of *genes*; the input list is not modified.

    Each gene independently, with probability MUTATION_RATE, receives additive
    Gaussian noise with mean 0 and standard deviation MUTATION_STD.  No
    clamping is done here.
    """
    # For every gene, random.random() is drawn first and random.gauss() only on a hit.
    return [
        gene + random.gauss(0, MUTATION_STD) if random.random() < MUTATION_RATE else gene
        for gene in genes
    ]
|
||
|
||
|
||
def crossover(p1: List[float], p2: List[float]) -> Tuple[List[float], List[float]]:
    """Single-point crossover of two equally long parent gene lists.

    A cut point is drawn uniformly from 1..len(p1)-1.  The first child takes
    p1's head and p2's tail; the second child takes the reverse.  Parents
    with fewer than two genes cannot be cut, so copies of them are returned
    unchanged (this includes empty parents, which previously made
    ``random.randint(1, -1)`` raise).  Children never alias the parents.
    """
    if len(p1) <= 1:
        return list(p1), list(p2)
    point = random.randint(1, len(p1) - 1)
    return p1[:point] + p2[point:], p2[:point] + p1[point:]
|
||
|
||
|
||
def init_population(seed_vals: List[float]) -> List[List[float]]:
    """Build the initial population of POP_SIZE individuals.

    Individual 0 is *seed_vals* converted to floats.  Every other individual
    is the same seed with independent ``random.uniform(-0.1, 0.1)`` jitter
    added to each gene.  No clamping to line bounds happens here.
    """
    return [
        [float(v) for v in seed_vals]
        if slot == 0
        else [float(v) + random.uniform(-0.1, 0.1) for v in seed_vals]
        for slot in range(POP_SIZE)
    ]
|
||
|
||
|
||
def read_csv_with_encoding(path: Path, required: set[str]):
    """Read a CSV by trying several encodings, and verify required columns.

    Encodings are tried in order until one decodes without UnicodeDecodeError.
    Files starting with a UTF-8 byte-order mark try "utf-8-sig" first.
    pandas' "utf-8" reader silently skips a BOM, so without this the file
    would be reported as plain "utf-8".  Callers that write the frame back
    with the returned encoding would then drop the BOM, which tools such as
    Excel rely on to detect UTF-8 Chinese text.

    Args:
        path: CSV file to read.
        required: Column names (compared after stripping whitespace from the
            headers) that must be present.

    Returns:
        ``(df, cols, enc)``: the DataFrame, a dict mapping each stripped
        header to the header as it appears in ``df``, and the encoding used.

    Raises:
        UnicodeDecodeError: if no candidate encoding can decode the file.
        ValueError: if any of *required* is missing.
        Other errors from ``pd.read_csv`` (e.g. FileNotFoundError) propagate.
    """
    encodings = ("utf-8", "utf-8-sig", "gbk")
    try:
        with open(path, "rb") as fh:
            if fh.read(3) == b"\xef\xbb\xbf":
                encodings = ("utf-8-sig", "utf-8", "gbk")
    except (OSError, TypeError):
        pass  # not sniffable here; pd.read_csv below reports any real problem

    last_error = None
    for enc in encodings:
        try:
            df = pd.read_csv(path, encoding=enc)
        except UnicodeDecodeError as exc:
            last_error = exc
            continue
        break
    else:
        # Every candidate encoding failed to decode the file.
        raise last_error

    cols = {c.strip(): c for c in df.columns}
    missing = required - set(cols)
    if missing:
        raise ValueError(f"{path} 缺少字段: {', '.join(sorted(missing))}")
    return df, cols, enc
|
||
|
||
|
||
def load_factory_lines(year: int, factory_name_cn: str):
    """Load the production lines of one factory from ``data/<year>/ProductionLine.csv``.

    Rows are selected where the stripped '工厂名' equals *factory_name_cn*.

    Returns:
        A list of ``(line_id, seed_factor, min_factor, max_factor)`` tuples in
        file order.  ``line_id`` is the stripped string form of '产线ID'; the
        other three come from '磨合系数', '系数最小值' and '系数最大值' as floats.

    Raises:
        ValueError: if a required column is missing or the factory has no rows.
    """
    csv_file = Path("data") / str(year) / "ProductionLine.csv"
    frame, col, _enc = read_csv_with_encoding(
        csv_file, required={"工厂名", "产线ID", "磨合系数", "系数最小值", "系数最大值"}
    )
    selected = frame[col["工厂名"]].astype(str).str.strip() == factory_name_cn
    if not selected.any():
        raise ValueError(f"ProductionLine.csv 中未找到工厂 {factory_name_cn}")

    id_col, seed_col, lo_col, hi_col = (
        col[key] for key in ("产线ID", "磨合系数", "系数最小值", "系数最大值")
    )
    # Iterate full rows (all columns) so each value keeps the per-row coercion
    # that iterrows applies to the whole record.
    return [
        (
            str(rec[id_col]).strip(),
            float(rec[seed_col]),
            float(rec[lo_col]),
            float(rec[hi_col]),
        )
        for _, rec in frame[selected].iterrows()
    ]
|
||
|
||
|
||
def apply_bounds(genes: List[float], bounds: List[Tuple[float, float]]) -> List[float]:
    """Clamp each gene into its positional (low, high) bound.

    Pairs are formed position by position; if one list is longer, its extra
    items are dropped from the result.
    """
    return list(map(clip, genes, bounds))
|
||
|
||
|
||
def latin_hypercube_samples(n_samples: int, bounds: List[Tuple[float, float]]) -> List[List[float]]:
    """Draw *n_samples* points by Latin hypercube sampling inside *bounds*.

    Each dimension's range ``(lo, hi)`` is split into *n_samples* equal strata.
    A random permutation assigns every sample a distinct stratum per
    dimension, and the coordinate is drawn uniformly inside that stratum.
    The list of points is shuffled before being returned.  Returns ``[]``
    when *n_samples* <= 0.
    """
    if n_samples <= 0:
        return []

    # One stratum permutation per dimension, shuffled in dimension order.
    permutations = []
    for _ in bounds:
        order = list(range(n_samples))
        random.shuffle(order)
        permutations.append(order)

    points = [
        [
            lo + (perm[k] + random.random()) / n_samples * (hi - lo)
            for (lo, hi), perm in zip(bounds, permutations)
        ]
        for k in range(n_samples)
    ]
    random.shuffle(points)
    return points
|
||
|
||
|
||
def main():
    """Calibrate one factory's per-line factors with a genetic algorithm.

    Command line: ``--factory <code>``, matched against '工厂英文名' with a
    fallback to '工厂中文名'.  The calibration year is read from ``year.json``
    (key ``"year"``) and selects ``data/<year>/factory_data.csv`` and
    ``data/<year>/ProductionLine.csv``.

    An individual holds one factor per production line.  It is clamped to
    that line's (系数最小值, 系数最大值) range before scoring with
    ``evaluate`` (lower is better).  Each generation proceeds as follows:

    * the best individual is carried over unchanged;
    * the rest are bred by crossover + mutation from the better half of the
      ranked population;
    * after STAGNATION_WINDOW generations without improvement,
      Latin-hypercube samples are injected.

    Both CSVs are updated only when the best error beats the '最小误差'
    already recorded for the factory.
    """
    parser = argparse.ArgumentParser(description="GA calibration for a single factory factor.")
    parser.add_argument(
        "--factory",
        required=True,
        help="Factory English code (matches '工厂英文名' in factory_data.csv).",
    )
    args = parser.parse_args()

    # set year (context manager so the file handle is closed deterministically)
    with open("year.json", "r", encoding="utf-8") as year_file:
        year = json.load(year_file)["year"]
    filename = f"{year}"
    csv_path = Path("data") / filename / "factory_data.csv"
    line_csv_path = Path("data") / filename / "ProductionLine.csv"

    factory_row, _factory_df, factory_cols = load_factory_row(csv_path, args.factory)
    factory_name_cn = str(factory_row[factory_cols["工厂中文名"]]).strip()
    seed_lines = load_factory_lines(year, factory_name_cn)
    line_ids = [lid for lid, _, _, _ in seed_lines]
    seed_vals = [seed for _, seed, _, _ in seed_lines]
    bounds = [(mn, mx) for _, _, mn, mx in seed_lines]
    # A missing (NaN) stored error means "no previous result": any finite score beats it.
    prev_best_error = (
        float(factory_row[factory_cols["最小误差"]])
        if pd.notna(factory_row[factory_cols["最小误差"]])
        else float("inf")
    )

    print(f"[START] 校准工厂 {args.factory} / {factory_name_cn} (产线数={len(line_ids)}, baseline_error={prev_best_error:.6f})")

    best_genes = None
    best_score = float("inf")
    last_improve_gen = -1
    population = init_population(seed_vals)

    for gen in range(GENERATIONS):
        # Explicit periodic collection; each evaluate() builds a full SimulationModel.
        if gen % 10 == 0:
            gc.collect()

        scored = []
        for i, indiv in enumerate(population):
            if i % 5 == 0:
                gc.collect()
            indiv = apply_bounds(indiv, bounds)
            score = evaluate(args.factory, factory_name_cn, line_ids, indiv)
            scored.append((score, indiv))
            if score < best_score:
                best_score = score
                best_genes = indiv
                last_improve_gen = gen

        scored.sort(key=lambda x: x[0])

        # Elitism: this generation's best individual survives unchanged.
        next_pop = [scored[0][1]]
        # Truncation selection: parents are drawn from the better half of the
        # ranked population (at least 3).  A bound of max(3, len(scored)) would
        # always be the whole population, making the ranking irrelevant.
        parent_pool = scored[: max(3, len(scored) // 2)]
        while len(next_pop) < POP_SIZE:
            parents = random.sample(parent_pool, 2)
            c1, c2 = crossover(parents[0][1], parents[1][1])
            next_pop.append(apply_bounds(mutate(c1), bounds))
            if len(next_pop) < POP_SIZE:
                next_pop.append(apply_bounds(mutate(c2), bounds))

        # Stagnation: inject Latin Hypercube samples to escape local optima.
        # Keeps the first POP_SIZE // 4 of next_pop (elite first) plus the samples;
        # breeding in the following generation refills the population to POP_SIZE.
        if last_improve_gen >= 0 and (gen - last_improve_gen) >= STAGNATION_WINDOW:
            lhs_samples = latin_hypercube_samples(max(POP_SIZE // 2, 2), bounds)
            lhs_samples = [apply_bounds(s, bounds) for s in lhs_samples]
            next_pop = next_pop[: POP_SIZE // 4] + lhs_samples
            next_pop = next_pop[:POP_SIZE]
            # Restart the stagnation counter so injection is not repeated every generation.
            last_improve_gen = gen
            print(f"[{args.factory}] Stagnation detected ({STAGNATION_WINDOW} gens). Injected {len(lhs_samples)} LHS samples.")

        population = next_pop
        print(f"[{args.factory}] Generation {gen+1}/{GENERATIONS}: best_error={best_score:.6f}")

    if best_genes is None:
        # No score ever compared below +inf (e.g. all inf/NaN, or GENERATIONS == 0):
        # there is no trustworthy result to persist.
        print(f"[SKIP] {args.factory}: 未获得有效的评估结果,CSV 未更新。")
        return

    best_genes = apply_bounds(best_genes, bounds)
    best_avg_factor = sum(best_genes) / len(best_genes)
    print(f"[DONE] {args.factory}: best_error={best_score:.6f} (prev best {prev_best_error:.6f})")

    if best_score < prev_best_error:
        update_factory_csv(csv_path, args.factory, best_avg_factor, best_score)
        update_production_line_csv(line_csv_path, factory_name_cn, line_ids, best_genes)
        print(f"[UPDATE] {args.factory} / {factory_name_cn}: avg_factor={best_avg_factor:.6f}, error={best_score:.6f} 已写入 {csv_path} 与 ProductionLine.csv")
    else:
        print(f"[SKIP] {args.factory}: 未优于历史最小误差,CSV 未更新。")
|
||
|
||
|
||
# Script entry point: run the calibration only when executed directly, not on import.
if __name__ == "__main__":
    main()
|