"""Tests for the ismijnverweg geolookup API.

``geoip2.database.Reader`` is patched with a mock that serves randomly
generated ASN/City records (see gen_testdata and get_mock_reader).
"""
import logging
import random
import re
from functools import partial
from ipaddress import ip_address, ip_network
from itertools import islice
from operator import itemgetter
from unittest.mock import MagicMock, patch

import geoip2.database
import geoip2.errors
import geoip2.models
from faker import Faker
from fastapi.testclient import TestClient

from main import app  # type: ignore

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize Faker for generating test data
fake = Faker()

# Create a test client whose requests appear to originate from a random
# IPv6 address; test_no_query expects the API to report this address.
fake_ipv6 = fake.ipv6()
client = TestClient(app, client=(fake_ipv6, 31337))


def _is_usable_network(network):
    """Return True if *network* is sized sensibly for a mocked entry.

    IPv4 networks must be /8 to /31 and IPv6 networks /56 to /64
    (inclusive), so that the network always has at least one host address.
    """
    if network.version == 4:
        return 8 <= network.prefixlen < 32
    return 56 <= network.prefixlen <= 64


def _random_network():
    """Return a unique random network: public IPv4 (25%) or IPv6 (75%)."""
    if random.random() < 0.25:
        return ip_network(fake.unique.ipv4_public(network=True))
    return ip_network(fake.unique.ipv6(network=True))


def gen_testdata(max_networks=10, attempts=50):
    """Generate some mocked up GeoIP2 City/ASN entries.

    Up to *attempts* random networks are drawn; the first *max_networks*
    that pass _is_usable_network are used. Fewer (possibly zero) entries
    are produced when not enough random networks qualify.

    :param max_networks: maximum number of entries to generate.
    :param attempts: maximum number of random networks to draw.
    :returns: ``(asns, cities)`` dicts, both keyed by the first host
        address (an ``ipaddress`` object) of each chosen network.
    """
    continents = ('EU', 'NA', 'SA', 'AS', 'AU')
    asns = {}
    cities = {}

    candidates = (_random_network() for _ in range(attempts))
    networks = islice(filter(_is_usable_network, candidates), max_networks)

    for network in networks:
        hostaddr = next(network.hosts())
        logger.info('Using %s from %s', hostaddr, network)
        asns[hostaddr] = geoip2.models.ASN(
            hostaddr,
            network=network,
            autonomous_system_organization=fake.company(),
            autonomous_system_number=fake.random_number(5))
        cities[hostaddr] = geoip2.models.City(
            locales=['en'],
            city={'names': {'en': fake.city()}},
            country={'iso_code': fake.country_code(),
                     'names': {'en': fake.country()}},
            continent={'code': random.choice(continents)})
    return asns, cities


def get_mock_reader(test_data):
    """Mock the geoip2.database.Reader.

    :param test_data: ``(asns, cities)`` tuple as returned by gen_testdata().
    :returns: a MagicMock whose ``with`` context value offers ``asn(ip)``
        and ``city(ip)``. Each returns the matching record from
        *test_data* or raises ``geoip2.errors.AddressNotFoundError``.
        The raw tuple is also exposed as the context value's ``test_data``.
    """
    asns, cities = test_data

    def _lookup(records, kind, ip):
        """Return the record for *ip* from *records* (an asns/cities dict)."""
        logger.info('Looking up %s info for %s', kind, ip)
        try:
            # The records are keyed by ipaddress objects, but callers may
            # pass plain strings, so normalise first. An unparsable
            # address raises ValueError here.
            return records[ip_address(ip)]
        except KeyError as exc:
            raise geoip2.errors.AddressNotFoundError(
                f'{ip} not in test database') from exc

    reader_ctx = MagicMock()
    reader_ctx.test_data = test_data
    reader_ctx.asn = partial(_lookup, asns, 'ASN')
    reader_ctx.city = partial(_lookup, cities, 'City')

    reader = MagicMock()
    # `with Reader(...) as r:` binds r to reader_ctx.
    reader.__enter__.return_value = reader_ctx
    return reader


def _patched_reader():
    """Return a patch() context replacing Reader with fresh mocked data."""
    return patch('geoip2.database.Reader',
                 return_value=get_mock_reader(gen_testdata()))


def test_no_query():
    """Test searching without a query parameter (the client's own IP)."""
    with _patched_reader():
        response = client.get("/")
    assert response.status_code == 200
    results = response.json()
    logger.info(results)
    assert results['ip'] == fake_ipv6
    assert len(results) > 0


def test_single_query():
    """Test searching with a single ip address."""
    fake_ipv4 = fake.ipv4_public()
    with _patched_reader():
        response = client.get(f"/{fake_ipv4}")
    assert response.status_code == 200
    results = response.json()
    logger.info(results)
    assert results['ip'] == fake_ipv4
    assert len(results) > 0


def test_multi_query():
    """Test searching with a batch of ip addresses in one POST request."""
    fake_ips = [{'ip': fake.ipv6() if random.random() > 0.5 else fake.ipv4()}
                for _ in range(16)]
    with _patched_reader():
        response = client.post("/", json=fake_ips)
    assert response.status_code == 200
    results = response.json()
    logger.info(results)
    assert len(results) > 0
    # Every returned address must be one of the submitted addresses.
    submitted = set(map(itemgetter('ip'), fake_ips))
    for ip in map(itemgetter('ip'), results):
        assert ip in submitted


def test_invalid_query():
    """Test searching with an invalid ip address."""
    invalid_ip = '500.312.77.31337'
    test_pattern = re.compile(r'Input is not a valid IPv[46] address')
    with _patched_reader():
        response = client.get(f"/{invalid_ip}")
    assert response.status_code == 422
    results = response.json()
    logger.info(results)
    details = results['detail']
    # all() is vacuously true for an empty list, so require an error entry.
    assert len(details) > 0
    assert all(detail['input'] == invalid_ip for detail in details)
    assert all(test_pattern.match(detail['msg']) for detail in details)


if __name__ == "__main__":
    # Run tests without a test runner
    test_no_query()
    test_single_query()
    test_invalid_query()
    test_multi_query()
    print("All tests passed!")