286 lines
14 KiB
Python
Executable file
286 lines
14 KiB
Python
Executable file
#!/usr/bin/env python3
|
|
import os
|
|
import argparse
|
|
from pathlib import Path
|
|
from functools import reduce
|
|
import json
|
|
import subprocess
|
|
import ipaddress
|
|
import socket
|
|
|
|
# Directory containing this script; the configs/ directory and nft rule files live next to it.
script_path = Path(__file__).parent

# Per-server JSON configs ("<name>.json") selectable on the command line.
local_config_path = script_path / "configs"

# Shared defaults that gen_configs() merges underneath every server config.
include_config = local_config_path / "config.common.inc"

# Candidate directories for generated shadowsocks configs; the first existing one is used.
ss_config_paths = [Path("/etc/shadowsocks"), Path("/etc/shadowsocks-rust"), Path("/etc/shadowsocks-libev")]

# systemd template unit prefix; ss_service_get() appends "<ss_prefix><name>.service".
ss_service_name = "shadowsocks-libev-redir@"  # "shadowsocks-rust@"

# Prefix of generated config files / unit instances.
ss_prefix = "autogen-"

ss_config_path = next((p for p in ss_config_paths if p.exists()), None)
if ss_config_path is None:
    # Fail with a readable message instead of an IndexError traceback at import time.
    raise SystemExit("[!] no shadowsocks config directory found, tried: "
                     + ", ".join(str(p) for p in ss_config_paths))


def ss_config_get(name) -> Path:
    """Return the generated config file path for instance *name* (str or int)."""
    return ss_config_path / f"{ss_prefix}{name}.json"


def ss_service_get(name) -> str:
    """Return the systemd unit name for generated instance *name* (str or int)."""
    return f"{ss_service_name}{ss_prefix}{name}.service"


# nft rule templates, chosen by process_nft_rule() from common["tcp_redir"].
nft_rule_redir = script_path / "transparent-proxy.nft"
nft_rule_v6_redir = script_path / "transparent-proxy-v6.nft"
nft_rule_tproxy = script_path / "transparent-proxy-tproxy.nft"
nft_rule_v6_tproxy = script_path / "transparent-proxy-v6-tproxy.nft"

# One CIDR per line; loaded into the nft "chnroute" sets by `up`.
chnroute = "/etc/dnsmasq.d/chinadns_chnroute.txt"
chnroute6 = "/etc/dnsmasq.d/chinadns_chnroute6.txt"

# Interface names emitted as the nft *_proxy_ifnames defines ("@empty_str" when empty).
proxy_interfaces = []

# IPv6 variant; leaving the first list empty falls back to proxy_interfaces.
proxy_interfaces_v6 = [] or proxy_interfaces

# Extra networks (CIDR strings) added to the chnroute sets of their IP family.
extra_bypass = []
|
|
|
|
|
|
def _validate_modes(modes, config_name: str) -> None:
    """Raise ValueError unless *modes* is a valid, complete, non-overlapping mode list.

    Every mode must be known. For each IP family that appears in any mode, exactly
    one mode must handle TCP and exactly one must handle UDP.
    """
    known = ("ipv4_tcp_udp", "ipv6_tcp_udp", "ipv4_tcp_only", "ipv4_udp_only",
             "ipv6_tcp_only", "ipv6_udp_only", "ipv4_ipv6_tcp_udp")
    if not modes:
        raise ValueError(f"{config_name}: no modes configured")
    unknown = [m for m in modes if m not in known]
    if unknown:
        raise ValueError(f"{config_name}: unknown modes {unknown}")
    for ipx in ("ipv4", "ipv6"):
        if not any(ipx in m for m in modes):
            continue  # family not enabled at all
        for proto in ("tcp", "udp"):
            matching = [m for m in modes if ipx in m and proto in m]
            if len(matching) > 1:
                raise ValueError(f"{config_name}: modes {matching} overlap on {ipx}/{proto}")
            if not matching:
                raise ValueError(f"{config_name}: {ipx} is enabled but no mode handles {proto}")


def gen_configs(config_name: str) -> dict:
    """Merge the shared include file with configs/<config_name>.json.

    Returns the parsed include dict, updated so that:
      * "common" is the merged common section (server config wins), key-sorted;
      * "modes" is the server's mode list (legacy configs keep the include's list);
      * each mode key maps to its merged per-mode config. Precedence is
        common < include per-mode section < server per-mode section.
    Keys copied as "#<key>" are bookkeeping: write_and_enable_configs() drops them
    and process_nft_rule() reads "#local_port". The first mode also gets
    "_meta_name" = config_name, which print_config_names() reads back.

    Raises ValueError for invalid tcp_redir/udp_redir values or mode lists. File,
    JSON and missing-key errors propagate unchanged.
    """
    # Keys shadowsocks-rust expects inside "locals"; the 'mvdel' ones are also
    # removed from the top level of shadowsocks-libev configs.
    RUST_ATTRS = {'mv': ('local_address', 'local_port', 'mode', 'protocol'), 'mvdel': ('tcp_redir', 'udp_redir')}
    all_rust_attrs = RUST_ATTRS['mv'] + RUST_ATTRS['mvdel']

    config_inc = json.loads(include_config.read_text())
    config = json.loads((local_config_path / f"{config_name}.json").read_text())

    # handle legacy config: a file without "modes" is a flat common section and
    # inherits the mode list from the include file
    legacy = "modes" not in config
    config_common = config if legacy else config["common"]
    common = dict(sorted({**config_inc["common"], **config_common}.items()))
    config_inc["common"] = common
    if not legacy:
        config_inc["modes"] = config["modes"]
    modes = config_inc["modes"]

    # Validate the merged values: process_nft_rule() picks its rule files from the
    # merged common section, so the generated ss configs must follow the same one.
    if common.get('tcp_redir') not in ('redirect', 'tproxy'):
        raise ValueError(f"{config_name}: tcp_redir must be 'redirect' or 'tproxy', got {common.get('tcp_redir')!r}")
    if common.get('udp_redir') != 'tproxy':
        raise ValueError(f"{config_name}: udp_redir must be 'tproxy', got {common.get('udp_redir')!r}")
    rust_only = common['tcp_redir'] != 'redirect'

    _validate_modes(modes, config_name)

    for idx, m in enumerate(modes):
        mode_cfg = dict(sorted({**common, **config_inc.get(m, {}), **config.get(m, {})}.items()))
        if rust_only:
            # shadowsocks-rust: move the local-side keys into one "locals" entry
            # (KeyError if any of them is missing).
            mode_cfg['locals'] = [{k: mode_cfg[k] for k in all_rust_attrs}]
            for attr in all_rust_attrs:
                mode_cfg[f"#{attr}"] = mode_cfg.pop(attr, None)
        else:
            # shadowsocks-libev: keep the keys in place, drop the rust-only redir keys.
            for attr in all_rust_attrs:
                mode_cfg[f"#{attr}"] = mode_cfg.get(attr)
            for attr in RUST_ATTRS['mvdel']:
                mode_cfg.pop(attr, None)
        if idx == 0:
            mode_cfg['_meta_name'] = config_name
        config_inc[m] = mode_cfg
    return config_inc
|
|
|
|
def print_config_names(do_print=True) -> str:
    """Return the name of the currently deployed config, optionally listing local configs.

    The deployed name is the "_meta_name" key of generated instance 0. It is ""
    when that file is missing, unreadable or malformed. With *do_print*, every
    configs/*.json is printed as "name (server:port)" in sorted order, and the
    deployed one is marked with a leading ">".
    """
    def get_current_up() -> str:
        try:
            return json.loads(ss_config_get(0).read_text())['_meta_name']
        except (OSError, ValueError, KeyError, TypeError):
            # missing/unreadable file, invalid JSON, or no _meta_name: nothing is up
            return ""

    current_up = get_current_up()
    if do_print:
        for conf in sorted(local_config_path.iterdir()):
            if not conf.name.endswith('.json'):
                continue
            name = conf.name[:-len('.json')]
            _c = gen_configs(name)
            c = _c[_c["modes"][0]]
            server_info = " %s \t(%s:%d)" % (name, c["server"], c["server_port"])
            if name == current_up:
                server_info = ">" + server_info[1:]
            print(server_info)
    return current_up
|
|
|
|
def _stop_if_active(service: str) -> bool:
    """Stop systemd unit *service* if `systemctl is-active` reports it active.

    Returns True only when the unit was active and `systemctl stop` succeeded.
    A failed stop prints a warning.
    """
    if subprocess.run(["systemctl", "is-active", service], check=False, capture_output=True).returncode:
        return False  # non-zero: not active, nothing to stop
    if subprocess.run(["systemctl", "stop", service], check=False).returncode:
        print(f"[!] systemctl stop {service} failed")
        return False
    return True


def stop_and_remove(config_name):
    """Stop generated instance *config_name* if it is active, then delete its config file."""
    _stop_if_active(ss_service_get(config_name))
    ss_config_get(config_name).unlink(missing_ok=True)


def stop_all_configs():
    """Stop every active instance that has a "<ss_prefix>*.json" file in ss_config_path.

    Config files are left in place.
    """
    for conf in ss_config_path.iterdir():
        if conf.name.endswith(".json") and conf.name.startswith(ss_prefix):
            service = ss_service_get(conf.name[len(ss_prefix):-len(".json")])
            # Only report "stopped" when the stop actually succeeded.
            if _stop_if_active(service):
                print(f"stopped {service}")
|
|
|
|
def write_and_enable_configs(config_dict, dry_run=False) -> bool:
    """Sync generated config files and systemd units with *config_dict* (from gen_configs()).

    * Generated configs whose instance name is not exactly "0".."len(modes)-1" are
      stopped and removed.
    * Mode i is written as instance "i", with keys starting with "#" stripped.
    * An instance is restarted when its unit is not active or its config changed.

    With *dry_run*, nothing is modified. Each needed action is printed as
    "check failed: ...".

    Returns True if something was changed (dry run: if something would change).
    """
    modes = config_dict['modes']
    # Compare names as strings: int() also accepts "00", "+1" or " 1", which would
    # keep such stale files (and their units) around.
    wanted = {str(i) for i in range(len(modes))}
    # Summary flags: [removed stale config, wrote config, restarted unit].
    changed = [False, False, False]
    pending = False  # dry run: at least one check failed

    def restart_service(name):
        nonlocal pending
        service = ss_service_get(name)
        if dry_run:
            print(f"check failed: should start {service}")
            pending = True
        else:
            if subprocess.run(["systemctl", "restart", service], check=False).returncode:
                print(f"[!] systemctl start {service} failed")
            changed[2] = True

    for conf in ss_config_path.iterdir():
        if not (conf.name.endswith(".json") and conf.name.startswith(ss_prefix)):
            continue
        name = conf.name[len(ss_prefix):-len(".json")]
        if name in wanted:
            continue
        if dry_run:
            print(f"check failed: should stop and remove {conf.name=}")
            pending = True
        else:
            stop_and_remove(name)
            changed[0] = True

    for idx, mode in enumerate(modes):
        cfgname = str(idx)
        cfg = ss_config_get(cfgname)
        old = cfg.read_text() if cfg.exists() else ""
        # "#"-prefixed keys are bookkeeping added by gen_configs(), not ss options.
        new = json.dumps({k: v for k, v in config_dict[mode].items() if not k.startswith("#")})
        config_same = new == old
        if not config_same:
            if dry_run:
                print(f"check failed: should write {cfgname} {mode}")
                pending = True
            else:
                cfg.write_text(new)
                changed[1] = True
        inactive = subprocess.run(["systemctl", "is-active", ss_service_get(cfgname)],
                                  check=False, capture_output=True).returncode
        if inactive or not config_same:
            restart_service(cfgname)

    if changed[0]:
        print("deleted old config")
    if changed[1]:
        print("wrote new config")
    if changed[2]:
        print("restart systemd")
    return any(changed) or pending
|
|
|
|
def invoke_self_with_sudo():
    """Re-run the current command line as root via sudo and return its exit status.

    Only valid when the caller is not already root.
    """
    assert os.getuid() != 0
    import sys
    command = ["sudo", sys.executable]
    command.extend(sys.argv)
    completed = subprocess.run(command, check=False)
    return completed.returncode
|
|
|
|
def prepare_cgroup_path(root=Path('/sys/fs/cgroup')):
    """Create the ss_bp*/ss_fw* slice directories under the cgroup v2 root if missing.

    :param root: cgroup v2 mount point (str or Path). It defaults to
        /sys/fs/cgroup and is overridable for non-standard mounts or tests.
    """
    slice_names = ('ss_bp.slice', 'ss_bp_tcp.slice', 'ss_bp_udp.slice',
                   'ss_fw.slice', 'ss_fw_tcp.slice', 'ss_fw_udp.slice')
    for slice_name in slice_names:
        (Path(root) / slice_name).mkdir(exist_ok=True)
|
|
|
|
def process_nft_rule(configs: dict) -> dict:
    """Render the nft ruleset text for every enabled IP family.

    Returns {4: text, 6: text}, limited to the families that appear in
    configs['modes']. Each text is a block of "define ..." lines followed by the
    non-empty lines of the matching rule file, starting at its
    "## DO NOT CHANGE THIS LINE" marker. Rule files are the redirect variants when
    configs['common']['tcp_redir'] == 'redirect', otherwise the tproxy variants.
    """
    if configs['common']['tcp_redir'] == 'redirect':
        rule_v4, rule_v6 = nft_rule_redir, nft_rule_v6_redir
    else:
        rule_v4, rule_v6 = nft_rule_tproxy, nft_rule_v6_tproxy

    def mode_for(family: int, l4proto: str) -> str:
        """First mode serving (family, l4proto); IndexError if there is none."""
        return [m for m in configs['modes'] if f"ipv{family}" in m and l4proto in m][0]

    def server_address(host: str, family: int):
        """Return *host* as an address of IP version *family*, or None if it has none."""
        try:
            candidates = [ipaddress.ip_address(host)]
        except ValueError:
            # Hostname: consider every resolved address, not only the first, so a
            # dual-stack server is matched in both the v4 and the v6 ruleset.
            candidates = [ipaddress.ip_address(info[4][0])
                          for info in socket.getaddrinfo(host, None, type=socket.SOCK_RAW)]
        return next((addr for addr in candidates if addr.version == family), None)

    def render(family: int) -> str:
        rule_file = rule_v6 if family == 6 else rule_v4
        lines = [line for line in rule_file.read_text().split('\n') if line]
        lines = lines[lines.index('## DO NOT CHANGE THIS LINE'):]

        tcp = configs[mode_for(family, 'tcp')]
        udp = configs[mode_for(family, 'udp')]
        tcp_server = server_address(tcp['server'], family)
        udp_server = server_address(udp['server'], family)
        empty_host = f"@empty_ipv{family}"

        ifnames = proxy_interfaces_v6 if family == 6 else proxy_interfaces
        ifnames_set = ("{ %s }" % ', '.join(f'"{x}"' for x in ifnames)) if ifnames else '@empty_str'

        nft_define = {
            'tcp_host': str(tcp_server) if tcp_server is not None else empty_host,
            'udp_host': str(udp_server) if udp_server is not None else empty_host,
            'tcp_proxy_ifnames': ifnames_set,
            'udp_proxy_ifnames': ifnames_set,
            'tcp_server_port': tcp['server_port'],
            'udp_server_port': udp['server_port'],
            # "#local_port" is the bookkeeping copy gen_configs() keeps of local_port.
            'tcp_local_port': tcp['#local_port'],
            'udp_local_port': udp['#local_port'],
        }
        return '\n'.join([f"define {k} = {v}" for k, v in nft_define.items()] + lines)

    return {family: render(family) for family in (4, 6)
            if any(f"ipv{family}" in m for m in configs['modes'])}
|
|
|
|
def flush_nft() -> bool:
    """Delete the transparent_proxy, transparent_proxy_v6 and output_deny nft tables.

    Each table is added right before it is deleted, all in one `nft -f -` batch.
    Presumably this keeps the delete from failing when a table does not exist yet.
    Returns True on success; prints a warning and returns False otherwise.
    """
    owned_tables = (
        ('ip', 'transparent_proxy'),
        ('ip6', 'transparent_proxy_v6'),
        ('ip6', 'output_deny'),
    )
    commands = []
    for family, table in owned_tables:
        commands += [f'add table {family} {table}', f'delete table {family} {table}']
    batch = '\n'.join(commands).encode('utf-8')
    failed = subprocess.run(["nft", "-f", "-"], input=batch, check=False).returncode != 0
    if failed:
        print("[!] nft flush failed")
    return not failed
|
|
|
|
def flush_iproute2() -> None:
    """Best-effort removal of routing table 100 and the fwmark 0xdeaf rule, IPv4 then IPv6.

    Exit codes are ignored and stderr is discarded. The IPv6 pass always runs.
    """
    batch = b"route flush table 100\nrule del fwmark 0xdeaf table 100"
    for family_args in ([], ["-6"]):
        subprocess.run(["ip", *family_args, "-force", "-batch", "-"],
                       input=batch, check=False, stderr=subprocess.DEVNULL)
|
|
|
|
def _ipvx_enabled(configs: dict, family: int) -> bool:
    """Return True if any mode in configs['modes'] mentions "ipv<family>"."""
    return any(f"ipv{family}" in m for m in configs['modes'])


def _action_info() -> int:
    """List local configs and dry-run check the currently deployed one."""
    name = print_config_names()
    if name:
        if (local_config_path / f"{name}.json").exists():
            write_and_enable_configs(gen_configs(name), dry_run=True)
        else:
            print(f"[!] current config {name}.json is missing")
    return 0


def _action_up(config_name) -> int:
    """Bring up *config_name*, defaulting to the currently deployed config.

    The steps are:
      * write and (re)start the shadowsocks instances;
      * add "route local default dev lo table 100" and "fwmark 0xdeaf -> table 100";
      * load the nft rules;
      * fill the chnroute sets.
    Returns 0 on success and 1 if any step reported a failure.
    """
    prepare_cgroup_path()
    if not config_name:
        config_name = print_config_names(do_print=False)
        if not config_name:
            print("[!] no config given and no current config found")
            return 1
        print("autoselected config %s" % config_name)

    configs = gen_configs(config_name)
    write_and_enable_configs(configs)
    families = [x for x in (4, 6) if _ipvx_enabled(configs, x)]
    # Rendered before anything is flushed below, so the server-name lookups in
    # process_nft_rule() run while the previous setup is still in place.
    nfts = {k: v.encode('utf-8') for k, v in process_nft_rule(configs).items()}
    ok = True

    flush_iproute2()
    ip_batch = '\n'.join(('route add local default dev lo table 100', 'rule add fwmark 0xdeaf table 100')).encode('utf-8')
    for x in families:
        if subprocess.run(["ip", f"-{x}", "-force", "-batch", "-"], input=ip_batch, check=False).returncode:
            print(f"[!] iproute2 ipv{x} failed")
            ok = False

    flush_nft()
    for x, nft in nfts.items():
        if subprocess.run(["nft", "-f", "-"], input=nft, check=False).returncode:
            print(f"[!] nft ipv{x} failed, flushing")
            flush_nft()
            return 1  # rules were flushed again, so skip loading chnroute

    bypass = [ipaddress.ip_network(net) for net in extra_bypass]
    for x in families:
        nets = list(filter(None, Path(chnroute6 if x == 6 else chnroute).read_text().split('\n')))
        nets.extend(str(net) for net in bypass if net.version == x)
        table = "ip6 transparent_proxy_v6" if x == 6 else "ip transparent_proxy"
        nft_chnroute_rule = '\n'.join(f"add element {table} chnroute {{ {ipx} }}" for ipx in nets).encode('utf-8')
        if subprocess.run(["nft", "-f", "-"], input=nft_chnroute_rule, check=False).returncode:
            print("[!] nft chnroute failed")
            ok = False
    return 0 if ok else 1


def _action_down(stop_all: bool) -> int:
    """Remove the policy route and nft tables; with *stop_all*, also stop generated units."""
    flush_iproute2()
    ok = flush_nft()
    if stop_all:
        stop_all_configs()
    return 0 if ok else 1


def main():
    """Parse the command line, dispatch to the chosen action, and return the exit status."""
    parser = argparse.ArgumentParser(description='ss.py')
    parser.add_argument('action', type=str, default='info', nargs='?', choices=['info', 'up', 'down'], help='what to do')
    parser.add_argument('config', type=str, default=None, nargs='?', help='config name')
    parser.add_argument('-s', '--stop-all', action='store_true', help='stop systemd units')
    args = parser.parse_args()

    if args.action == 'info':
        return _action_info()
    # 'up' and 'down' modify system state: re-run through sudo unless already root.
    if os.getuid() != 0:
        return invoke_self_with_sudo()
    if args.action == 'up':
        return _action_up(args.config)
    return _action_down(args.stop_all)
|
|
|
|
if __name__ == "__main__":
    # SystemExit needs no import and, unlike the site-provided exit(), also works under `python -S`.
    raise SystemExit(main() or 0)
|