# File: gw/build_distance_matrix.py
#
# 66 lines
# 2.3 KiB
# Python

import math
from pathlib import Path
import pandas as pd
def haversine(lat1, lon1, lat2, lon2, radius_km=6371.0):
    """Return the great-circle distance between two points on a sphere.

    Uses the haversine formula.

    Args:
        lat1, lon1: Latitude and longitude of the first point, in degrees.
        lat2, lon2: Latitude and longitude of the second point, in degrees.
        radius_km: Sphere radius. The default is the mean Earth radius in km;
            pass another value to get the distance in that value's unit.

    Returns:
        The distance, in the same unit as ``radius_km``. NaN inputs yield NaN.
    """
    phi1, phi2 = math.radians(lat1), math.radians(lat2)
    dphi = math.radians(lat2 - lat1)
    dlambda = math.radians(lon2 - lon1)
    a = math.sin(dphi / 2) ** 2 + math.cos(phi1) * math.cos(phi2) * math.sin(dlambda / 2) ** 2
    # Floating-point rounding can push ``a`` marginally outside [0, 1] (e.g. for
    # near-antipodal points), which would make one of the square roots below
    # raise ValueError. Clamp it, but with explicit comparisons so that a NaN
    # ``a`` still propagates instead of silently becoming a distance of 0.
    if a > 1.0:
        a = 1.0
    elif a < 0.0:
        a = 0.0
    c = 2 * math.atan2(math.sqrt(a), math.sqrt(1 - a))
    return radius_km * c
def _read_csv_flexible(path: str, header: int = 0):
    """Read a CSV file, trying several text encodings in turn.

    Encodings are tried in this order: UTF-8, UTF-8 with BOM, GBK, Latin-1.
    The first one that decodes without ``UnicodeDecodeError`` is used.

    Args:
        path: Path of the CSV file.
        header: Row number to use as the column names. It is forwarded to
            ``pandas.read_csv``; pass ``None`` if the file has no header row.

    Returns:
        The parsed ``pandas.DataFrame``.
    """
    for enc in ("utf-8", "utf-8-sig", "gbk", "latin1"):
        try:
            return pd.read_csv(path, header=header, encoding=enc, engine="python")
        except UnicodeDecodeError:
            # Wrong guess for this file; try the next encoding.
            continue
    # Latin-1 can decode any byte sequence, so in practice the loop above
    # always returns. This is a last resort: it drops undecodable bytes
    # instead of failing.
    return pd.read_csv(
        path, header=header, encoding="utf-8", engine="python", encoding_errors="ignore"
    )
def _infer_coords(df, name_col, lat_col, lon_col):
    """Extract a clean ``name`` / ``lat`` / ``lon`` table from *df*.

    The coordinate columns are coerced to numbers, and unparseable values
    become NaN. If *lat_col* looks less like latitudes than *lon_col* looks
    like longitudes, the two are assumed to be swapped and are exchanged.
    Rows with any missing value are then dropped. *df* itself is not modified.

    Args:
        df: Source table.
        name_col: Column holding the location name.
        lat_col: Column believed to hold latitudes, in degrees.
        lon_col: Column believed to hold longitudes, in degrees.

    Returns:
        A new DataFrame with exactly the columns ``name``, ``lat``, ``lon``.
    """
    # Work on a private copy of just the three needed columns. This avoids
    # mutating the caller's frame, and it prevents other columns that happen
    # to be called "name"/"lat"/"lon" from colliding after the rename.
    out = df[[name_col, lat_col, lon_col]].copy()
    out[lat_col] = pd.to_numeric(out[lat_col], errors="coerce")
    out[lon_col] = pd.to_numeric(out[lon_col], errors="coerce")
    # Score each column only over its parseable values. Counting NaN as
    # "out of range" would let a few blanks in a valid latitude column
    # trigger a spurious swap.
    lat_range = out[lat_col].dropna().between(-90, 90).mean()
    lon_range = out[lon_col].dropna().between(-180, 180).mean()
    # Every valid latitude is also a valid longitude, so a "latitude" column
    # that scores worse is most likely holding longitudes. Swap which column
    # gets which label instead of moving the data.
    if lat_range < lon_range:
        lat_col, lon_col = lon_col, lat_col
    out = out.rename(columns={name_col: "name", lat_col: "lat", lon_col: "lon"})
    return out[["name", "lat", "lon"]].dropna()
def main():
    """Build the demand-to-factory distance table and write it to CSV.

    Reads ``data/DemandLocation.csv`` and ``data/FactoryLocation.csv``
    (relative to the current working directory). Computes the haversine
    distance for every (demand location, factory) pair and writes the result
    to ``data/distance_matrix.csv`` with the columns ``demand_city``,
    ``factory`` and ``distance_km``.
    """
    demand = _read_csv_flexible("data/DemandLocation.csv")
    factory = _read_csv_flexible("data/FactoryLocation.csv")
    # Both files: assume the first column is the name and the last two are
    # the coordinates (their lat/lon order is auto-detected).
    demand_coords = _infer_coords(demand, demand.columns[0], demand.columns[-2], demand.columns[-1])
    factory_coords = _infer_coords(factory, factory.columns[0], factory.columns[-2], factory.columns[-1])

    # Iterate column values directly rather than with iterrows(). iterrows()
    # upcasts each row to a common dtype, which would turn numeric IDs such
    # as 1 into 1.0. Materialize the factories once, outside the outer loop.
    factories = list(zip(factory_coords["name"], factory_coords["lat"], factory_coords["lon"]))
    rows = []
    for d_name, d_lat, d_lon in zip(demand_coords["name"], demand_coords["lat"], demand_coords["lon"]):
        for f_name, f_lat, f_lon in factories:
            rows.append(
                {
                    "demand_city": d_name,
                    "factory": f_name,
                    "distance_km": haversine(d_lat, d_lon, f_lat, f_lon),
                }
            )
    # Explicit columns keep the CSV header even when there are no pairs.
    df = pd.DataFrame(rows, columns=["demand_city", "factory", "distance_km"])
    out_path = Path("data") / "distance_matrix.csv"
    df.to_csv(out_path, index=False, encoding="utf-8-sig")
    print(f"Saved {len(df)} rows to {out_path}")


if __name__ == "__main__":
    main()