Skip to content
Snippets Groups Projects
Commit e5d806d6 authored by Popal, Massi's avatar Popal, Massi
Browse files

Upload New File

parent bafa115d
No related branches found
No related tags found
No related merge requests found
gbac.py 0 → 100644
#!/usr/bin/python3
import numpy as np
import networkx as nx
import pandas as pd
from typing import List, Tuple, Dict
import argparse
import logging
from collections import Counter, defaultdict
from ipaddress import ip_network, ip_address, ip_address as validate_ip
import random
import requests
import json
from flask import Flask, request, jsonify
import ssl
import os
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
def generate_dynamic_whitelist(alerts: np.ndarray, min_connections: int = 2) -> List[Tuple[str, str, str]]:
"""Generate a whitelist of destination IP, port, and protocol tuples based on connection frequency."""
dst_pairs = [(alert[3], str(alert[4]), alert[6]) for alert in alerts]
pair_counts = Counter(dst_pairs)
whitelist = [(dst_ip, dst_port, proto) for (dst_ip, dst_port, proto), count in pair_counts.items() if count >= min_connections]
logging.info(f"Generated dynamic whitelist with {len(whitelist)} entries: {whitelist[:10]}...")
return whitelist
def is_whitelisted(alert: np.ndarray, whitelist: List[Tuple[str, str, str]]) -> bool:
"""Check if an alert matches a whitelisted destination IP, port, and protocol tuple."""
src_ip, src_port, dst_ip, dst_port, proto = alert[1], alert[2], alert[3], alert[4], alert[6]
return (dst_ip, str(dst_port), proto) in whitelist
def validate_alert(alert):
"""Validate alert data for correct IP addresses, port ranges, and timestamp format."""
try:
validate_ip(alert[1])
validate_ip(alert[3])
src_port, dst_port = int(alert[2]), int(alert[4])
if not (0 <= src_port <= 65535 and 0 <= dst_port <= 65535):
raise ValueError("Port number out of valid range")
float(alert[5])
return True
except (ValueError, TypeError) as e:
logging.warning(f"Invalid alert skipped: {alert}, error: {e}")
return False
def parse_log_file(log_file: str, log_type: str = 'conn', time_window: float = 60.0, network_id: str = None) -> np.ndarray:
"""Parse a log file into an array of alerts, filtering based on suspicious behavior."""
logging.info(f"Parsing {log_type} file: {log_file} for network {network_id}")
if log_file.endswith('.log'):
df = pd.read_csv(log_file, sep='\t', comment='#', header=None)
if log_type == 'conn':
names = ['ts', 'uid', 'id.orig_h', 'id.orig_p', 'id.resp_h', 'id.resp_p', 'proto', 'service',
'duration', 'orig_bytes', 'resp_bytes', 'conn_state', 'local_orig', 'local_resp',
'missed_bytes', 'history', 'orig_pkts', 'orig_ip_bytes', 'resp_pkts', 'resp_ip_bytes',
'tunnel_parents', 'ip_proto']
df.columns = names
required_fields = ['uid', 'id.orig_h', 'id.orig_p', 'id.resp_h', 'id.resp_p', 'ts', 'proto', 'conn_state']
df.replace('-', np.nan, inplace=True)
df = df.infer_objects(copy=False)
df = df.dropna(subset=['id.orig_h', 'ts'])
alerts = df[required_fields].to_numpy()
valid_alerts = [alert for alert in alerts if validate_alert(alert)]
if len(valid_alerts) < len(alerts):
logging.warning(f"Dropped {len(alerts) - len(valid_alerts)} invalid alerts")
whitelist = generate_dynamic_whitelist(valid_alerts)
src_timestamps = defaultdict(list)
for alert in valid_alerts:
src_timestamps[alert[1]].append(float(alert[5]))
src_counts = Counter([alert[1] for alert in valid_alerts])
filtered_alerts = []
total_alerts = len(valid_alerts)
for i, alert in enumerate(valid_alerts):
if i % 50000 == 0:
logging.info(f"Filtering progress: {i}/{total_alerts} alerts processed")
src = alert[1]
count = src_counts[src]
relevant_ts = src_timestamps[src]
is_suspicious = (
(log_type == 'conn' and str(alert[7]) in ['REJ', 'S0'] and count > 20 and
(max(relevant_ts) - min(relevant_ts) <= time_window if len(relevant_ts) > 1 else True)) or
(log_type == 'conn' and str(alert[6]) == 'icmp' and count > 50) or
(log_type == 'conn' and str(alert[7]) == 'SF' and count > 10 and
(max(relevant_ts) - min(relevant_ts) <= time_window if len(relevant_ts) > 1 else True))
)
if is_suspicious or not is_whitelisted(alert, whitelist):
filtered_alerts.append(np.append(alert, network_id)) # Append network_id to each alert
logging.info(f"Parsed {total_alerts} alerts, {len(filtered_alerts)} remain after filtering")
return np.array(filtered_alerts)
def alert_similarity_graph(alerts: np.ndarray, similarity_threshold: float = 0.25, time_window: float = 60.0) -> nx.Graph:
"""Construct an undirected graph connecting alerts with sufficient similarity within a time window."""
G = nx.Graph()
alerts_sorted = sorted(alerts, key=lambda x: float(x[5]))
for i, alert_1 in enumerate(alerts_sorted):
uid_1 = alert_1[0]
G.add_node(uid_1)
j = i + 1
while j < len(alerts_sorted) and float(alerts_sorted[j][5]) - float(alert_1[5]) <= time_window:
alert_2 = alerts_sorted[j]
uid_2 = alert_2[0]
if uid_1 == uid_2:
j += 1
continue
G.add_node(uid_2)
sim = sum(0.25 for attr_index in range(1, 5) if alert_1[attr_index] == alert_2[attr_index])
if len(alert_1) > 7 and len(alert_2) > 7 and not pd.isna(alert_1[7]) and not pd.isna(alert_2[7]) and alert_1[7] == alert_2[7]:
sim += 0.25
if sim >= similarity_threshold:
G.add_edge(uid_1, uid_2, weight=sim)
j += 1
logging.info(f"Created similarity graph with {G.number_of_nodes()} nodes and {G.number_of_edges()} edges")
return G
def netflow_graph(alerts: np.ndarray) -> nx.DiGraph:
"""Construct a directed graph representing network flow from source to destination IPs."""
G = nx.DiGraph()
for alert in alerts:
src, dst = alert[1], alert[3]
if src != dst:
G.add_node(src)
G.add_node(dst)
G.add_edge(src, dst)
logging.info(f"Created netflow graph with {G.number_of_nodes()} nodes and {G.number_of_edges()} edges")
return G
def cluster_cliques(G: nx.Graph, k: int = 3) -> List[frozenset]:
"""Identify k-clique communities in the similarity graph."""
communities = list(nx.algorithms.community.k_clique_communities(G, k))
logging.info(f"Found {len(communities)} k-clique communities with k={k}")
return communities
def get_alerts_by_uid(alerts: np.ndarray, uid_list: List[str]) -> np.ndarray:
"""Extract alerts corresponding to a list of UIDs."""
uid_set = set(uid_list)
return np.array([alert for alert in alerts if alert[0] in uid_set])
def is_in_same_subnet(ip1: str, ip2: str, subnet: str = "147.32.84.0/24") -> bool:
"""Determine if two IP addresses belong to the same subnet."""
net = ip_network(subnet, strict=False)
return ip_address(ip1) in net and ip_address(ip2) in net
def infer_label(directed_graph: nx.DiGraph, alerts: np.ndarray, time_window: float = 60.0,
exploit_threshold: int = 10, scan_threshold: int = 20, subnet: str = "147.32.84.0/24") -> List[Tuple[float, str, List[str], List[str]]]:
attackers = [node for node in directed_graph.nodes if directed_graph.out_degree(node) >= 1]
victims = [node for node in directed_graph.nodes if directed_graph.in_degree(node) >= 1]
V, A, T = len(directed_graph.nodes), len(attackers), len(victims)
if len(alerts) < 10:
return [(0.0, "no attack (insufficient alerts)", attackers, victims)]
oto = 1/3 * ((V-A)/(V-1) + (V-T)/(V-1) + (V-abs(A-T))/V) if V > 1 else 0
otm = 1/3 * ((V-A)/(V-1) + T/(V-1) + abs(A-T)/(V-2)) if V > 2 else 0
mto = 1/3 * (A/(V-1) + (V-T)/(V-1) + abs(A-T)/(V-2)) if V > 2 else 0
mtm = 1/3 * (A/V + T/V + (V-abs(A-T))/V) if V > 0 else 0
src_dst_pairs = [(alert[1], alert[3]) for alert in alerts]
pair_counts = Counter(src_dst_pairs)
timestamps = [float(alert[5]) for alert in alerts]
proto = [alert[6] for alert in alerts]
conn_state = [alert[7] if len(alert) > 7 else None for alert in alerts]
src_targets = Counter([alert[1] for alert in alerts])
iocs = []
# Portscan-Erkennung
for src, count in src_targets.items():
if count >= scan_threshold:
relevant_timestamps = [ts for alert, ts in zip(alerts, timestamps) if alert[1] == src]
relevant_states = [state for alert, state in zip(alerts, conn_state) if alert[1] == src]
if (max(relevant_timestamps) - min(relevant_timestamps) <= time_window and
any(state == 'S0' for state in relevant_states if state)):
specific_attackers = [src]
specific_victims = list(set(alert[3] for alert in alerts if alert[1] == src))
iocs.append((otm, "one-to-many (Reconnaissance - Portscan)", specific_attackers, specific_victims))
# Exploit-Erkennung
for (src, dst), count in pair_counts.items():
if count >= exploit_threshold:
relevant_timestamps = [ts for alert, ts in zip(alerts, timestamps) if alert[1] == src and alert[3] == dst]
relevant_states = [state for alert, state in zip(alerts, conn_state) if alert[1] == src and alert[3] == dst]
if (max(relevant_timestamps) - min(relevant_timestamps) <= time_window and
sum(1 for state in relevant_states if state == 'SF') >= exploit_threshold):
iocs.append((oto, "one-to-one (Exploitation)", [src], [dst]))
# DDoS-Erkennung
if A > 1 and T == 1 and A >= 5:
iocs.append((mto, "many-to-one (Coordinated Attack - DDoS)", attackers, victims))
# Worm-Erkennung
if A > 1 and T > 1 and A >= 5 and T >= 5:
if any(victim in attackers for victim in victims):
iocs.append((mtm, "many-to-many (Worm Propagation)", attackers, victims))
if not iocs:
return [(0.0, "no attack", attackers, victims)]
# False Positive Check
filtered_iocs = []
for certainty, pattern, atkrs, victs in iocs:
if random.random() < 0.1 and pattern != "one-to-many (Reconnaissance - Portscan)":
filtered_iocs.append((0.0, f"{pattern} (possible false positive)", atkrs, victs))
else:
filtered_iocs.append((certainty, pattern, atkrs, victs))
return filtered_iocs
def detect_iocs(alerts: np.ndarray, similarity_threshold: float = 0.25, clique_size: int = 3,
time_window: float = 60.0, exploit_threshold: int = 10, scan_threshold: int = 20,
subnet: str = "147.32.84.0/24") -> List[Tuple[float, str, np.ndarray, List[str], List[str]]]:
sim_graph = alert_similarity_graph(alerts, similarity_threshold, time_window)
cliques = cluster_cliques(sim_graph, clique_size)
iocs = []
for clique in cliques:
alerts_in_clique = get_alerts_by_uid(alerts, list(clique))
flow_graph = netflow_graph(alerts_in_clique)
detected_iocs = infer_label(flow_graph, alerts_in_clique, time_window, exploit_threshold, scan_threshold, subnet)
for certainty, pattern_name, attackers, victims in detected_iocs:
if certainty > 0:
iocs.append((certainty, pattern_name, alerts_in_clique, attackers, victims))
logging.info(f"Detected IoC: {pattern_name} with certainty {certainty:.2f}, {len(attackers)} attackers, {len(victims)} victims")
return iocs
def send_iocs_to_server(iocs, server_url, network_id):
"""Transmit detected IoCs to the central server via HTTPS."""
iocs_with_id = [{"certainty": c, "pattern": p, "alerts": a.tolist(), "attackers": atk, "victims": v,
"network_id": network_id, "timestamp": float(a[0][5])}
for c, p, a, atk, v in iocs]
try:
response = requests.post(f"{server_url}/submit_ioc", json=iocs_with_id, verify=False)
if response.status_code == 200:
logging.info(f"IoCs successfully sent to server (Network {network_id})")
else:
logging.error(f"Error sending IoCs: {response.status_code}")
except Exception as e:
logging.error(f"Connection error sending to server: {e}")
app = Flask(__name__)
@app.route('/alert', methods=['POST'])
def receive_alert():
"""Receive and log alerts from the server."""
alert_plus = request.json
if alert_plus["source_network"] == alert_plus["current_network"]:
logging.info(f"Warning: Multi-step attack detected in this network ({alert_plus['source_network']})!")
else:
logging.info(f"Warning: Multi-step attack detected in {alert_plus['source_network']}!")
logging.info(f"Attack type: {alert_plus['type']}")
logging.info(f"Attackers: {alert_plus['attackers']}")
logging.info(f"Victims: {alert_plus['victims']}")
logging.info(f"Recommendation: {alert_plus['recommendation']}")
return jsonify({"status": "received"}), 200
def main():
parser = argparse.ArgumentParser(description="Graph-based IoC Detection with Dynamic Whitelist")
parser.add_argument('--conn_files', nargs='+', type=str, help="Paths to network conn log files (e.g., conn1.log conn2.log)")
parser.add_argument('--network_ids', nargs='+', type=str, required=True, help="Unique identifiers for home networks (e.g., HomeNet1 HomeNet2)")
parser.add_argument('--server_url', type=str, default="https://localhost:443", help="URL of the central server")
parser.add_argument('--similarity_threshold', type=float, default=0.25, help="Similarity threshold for alert clustering")
parser.add_argument('--clique_size', type=int, default=3, help="Minimum clique size for community detection")
parser.add_argument('--time_window', type=float, default=60.0, help="Time window in seconds for detection")
parser.add_argument('--exploit_threshold', type=int, default=10, help="Threshold for exploitation connection count")
parser.add_argument('--scan_threshold', type=int, default=20, help="Threshold for scan target count")
parser.add_argument('--subnet', type=str, default="147.32.84.0/24", help="Subnet for LAN detection")
args = parser.parse_args()
if len(args.conn_files) != len(args.network_ids):
raise ValueError("Number of conn_files must match number of network_ids")
try:
alerts_list = []
for conn_file, network_id in zip(args.conn_files, args.network_ids):
alerts = parse_log_file(conn_file, log_type='conn', time_window=args.time_window, network_id=network_id)
alerts_list.append(alerts)
if not alerts_list:
raise ValueError("At least one conn_file must be provided")
alerts = np.concatenate(alerts_list) if len(alerts_list) > 1 else alerts_list[0]
iocs = detect_iocs(alerts, args.similarity_threshold, args.clique_size, args.time_window,
args.exploit_threshold, args.scan_threshold, args.subnet)
if not iocs:
print("No Indicators of Compromise (IoCs) detected.")
else:
for certainty, pattern, alerts_in_clique, attackers, victims in iocs:
print(f"Pattern: {pattern}, Certainty: {certainty:.2f}")
print(f"Attackers: {attackers}")
print(f"Victims: {victims}")
print(f"Number of alerts: {len(alerts_in_clique)}")
print("Sample alerts:")
for alert in alerts_in_clique[:5]:
extra_field = alert[7] if len(alert) > 7 else None
extra_label = "Query" if isinstance(extra_field, str) and extra_field not in ['REJ', 'S0', 'SF'] else "State"
network_id = alert[-1] # Network ID is the last field
print(f" UID: {alert[0]}, Src: {alert[1]}:{alert[2]}, Dst: {alert[3]}:{alert[4]}, TS: {alert[5]}, Proto: {alert[6]}, {extra_label}: {extra_field}, Network: {network_id}")
print("---")
if iocs:
for network_id in args.network_ids:
send_iocs_to_server(iocs, args.server_url, network_id)
context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
cert_path = "cert.pem"
key_path = "key.pem"
if not (os.path.exists(cert_path) and os.path.exists(key_path)):
raise FileNotFoundError("SSL certificates (cert.pem, key.pem) not found. Generate them with OpenSSL.")
context.load_cert_chain(cert_path, key_path)
app.run(host='0.0.0.0', port=8080, ssl_context=context, threaded=False)
except Exception as e:
print(f"An error occurred: {e}")
logging.error(f"Error in main: {e}")
if __name__ == "__main__":
main()
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment