
##  IMC'16 Invalid Certificate Study 
#   This module links certificates by iteratively applying our linking methodology (see multi_filter.py).
#
#   Input:   (1) certificates sorted by each field we're interested in, and (2) the list of field names on which we iteratively run the linking process.
#   Output:  the file paths where the linked certificates are stored.

import sys
import os
import json
import re2 as re
from datetime import datetime
from operator import itemgetter
import inspect

import time
import ast
import pickle

def isOverlap(hostlist1, hostlist2):
    """Return True when hostlist2 starts before hostlist1 ends.

    Each host list is a sequence of entries whose first element is a scan
    date.  The first entry is taken as the list's start and the last entry as
    its end, so entries are assumed to be in chronological order.  Only one
    direction is checked (hostlist2's start against hostlist1's end); callers
    are expected to pass the earlier-starting list as hostlist1.

    An empty host list never overlaps anything.
    """
    if not hostlist1 or not hostlist2:
        return False
    # Touching (start == end) does not count as an overlap.
    return hostlist2[0][0] < hostlist1[-1][0]

# Scan dates (YYYYMMDD ints) that overlapTime exempts from its per-date tally
# of certificates seen in exactly one scan.
# NOTE(review): the name suggests each of these dates had two scans -- TODO
# confirm against the scan schedule.
two_scans = [20120610, 20131111, 20131209, 20131216, 20131224, 20140106, 20140113, 20140120, 20140129]

def overlapTime(rids_hostlist):
    """Return True when a group of certificates looks like it spans several devices.

    rids_hostlist maps a record id to its host list, e.g.
    {'cert1': hostlist}.  A host list is a list of entries whose first element
    is a scan date, assumed to be in chronological order.

    Host lists are ordered by their first entry, and the earliest one seeds
    the overlap check.  Each later, non-empty host list makes the result True
    when:

    * it holds exactly one entry whose date is not in two_scans, and it is
      the third or later such single-entry list on that date; or
    * it holds several entries and starts before the previously checked
      multi-entry host list ends (see isOverlap).

    Empty host lists are ignored, and an empty mapping returns False.

    NOTE(review): the earliest host list is never counted toward the
    single-scan tally, even when it has one entry -- TODO confirm intended.
    """
    if not rids_hostlist:
        return False

    # Sorting by the whole host list orders primarily by its first (date) entry.
    ordered = sorted(rids_hostlist.items(), key=itemgetter(1))
    last_hostlist = ordered[0][1]

    single_scan_counts = {}
    for _rid, hostlist in ordered[1:]:
        if not hostlist:
            continue

        if len(hostlist) == 1:
            date = hostlist[0][0]
            if date not in two_scans:
                single_scan_counts[date] = single_scan_counts.get(date, 0) + 1
                if single_scan_counts[date] > 2:
                    return True
            # Single-scan lists never take part in the overlap chain.
            continue

        if isOverlap(last_hostlist, hostlist):
            return True

        last_hostlist = hostlist
    return False


def strToList(a):
    """Parse a serialized host-list column into a Python list.

    The column holds host entries without quotes.  These rewrites wrap the
    contents of each inner list in double quotes and strip whitespace, then
    evaluate the result with ast.literal_eval.  For example
    "[(20131111,[1.2.3.4/1]),(20131209,[5.6.7.8/2])]" becomes
    [(20131111, ['1.2.3.4/1']), (20131209, ['5.6.7.8/2'])].

    Returns [] when the rewritten text is not a valid Python literal.
    """
    # Applied strictly in order; later steps depend on the placeholders
    # ("+[" and ")+(") introduced by earlier ones.
    rewrites = (
        (",[", "+["),        # 1. protect the comma that opens an inner list
        ("),(", ")+("),      # 2. protect the comma between tuples
        ("])", "\"])"),      # close the quoted string before an inner list ends
        (",", "\",\""),      # NOTE(review): this and the last step appear to
                             # cancel out -- kept as-is, TODO confirm before removing
        ("+[", ",[\""),      # recover 1. and open the quoted string
        (")+(", "),("),      # recover 2.
        ("\n", ""),
        (" ", ""),
        ("\",\"", ","),
    )
    c = a
    for old, new in rewrites:
        c = c.replace(old, new)

    try:
        return ast.literal_eval(c)
    except (SyntaxError, ValueError, TypeError):
        # literal_eval reports malformed input with any of these, not only
        # SyntaxError; treat them all as "no parsable host list".
        return []

def getdate(hlist):
    """Return every run of 8 digits in *hlist*, as strings, in order of appearance.

    These are presumably the YYYYMMDD scan dates of a host-list string.
    """
    eight_digits = r"[0-9]{8}"
    found = re.findall(eight_digits, hlist)
    return found

def getDiffDate(datelist):
    """Return the number of days between the earliest and latest date.

    datelist holds YYYYMMDD dates as strings or ints.  An empty list gives 0
    and a single date gives 1.  Two or more dates give the plain day
    difference, which is 0 when all dates are equal.

    Raises ValueError if a value is not a valid YYYYMMDD date.
    """
    if not datelist:
        return 0
    if len(datelist) == 1:
        return 1

    # A real list: on Python 3 ``map`` is a one-shot iterator, so min() would
    # exhaust it and the following max() would raise.
    numeric = [int(d) for d in datelist]
    earliest = datetime.strptime(str(min(numeric)), "%Y%m%d")
    latest = datetime.strptime(str(max(numeric)), "%Y%m%d")
    return (latest - earliest).days

def getIPs_v0(hlist):
    """Return every dotted-quad (a.b.c.d) address in *hlist*, in order, duplicates kept."""
    dotted_quad = r"[0-9]+\.[0-9]+\.[0-9]+\.[0-9]+"
    addresses = re.findall(dotted_quad, hlist)
    return addresses

def getIPs_v1(hlist):
    """Return the integers N found in ", N)" patterns of *hlist*, as a list.

    These are the numeric second elements of tuples such as (20131111, 5).
    When there are none, returns [0] so that max()/sum() on the result are
    always defined.
    """
    matches = re.findall(r", [0-9]+\)", hlist)
    if not matches:
        return [0]
    # A list (not a lazy map) because callers consume the result more than
    # once, e.g. max() followed by sum().
    return [int(m[1:-1]) for m in matches]  # m[1:-1] drops "," and ")"

def getUniqueNumIPs(IPs):
    """Return how many distinct values occur in *IPs*."""
    distinct = {ip for ip in IPs}
    return len(distinct)

def getASdistribution(hlist):
    as_dict = {}
    for asNum in map(lambda v: v[1:-2], re.findall("/[0-9]+']", hlist)):
        if(asNum not in as_dict):
            as_dict[asNum] = 0
        as_dict[asNum] += 1
    return json.dumps(as_dict)

def log2hostlist(isValid, public_key, num_days, cert_life, uniqueIPs, numIPs, ip, as_dist, out=None):
    """Write one tab-separated diagnostic record for a certificate.

    isValid, public_key, ip and as_dist must already be strings; the numeric
    fields are converted with str().  The record is terminated by a newline.

    out: writable text stream for the record; defaults to sys.stdout.  No
    module-level writer exists for this, so the destination is passed in
    rather than read from a global.
    """
    if out is None:
        out = sys.stdout
    fields = [isValid, public_key, str(num_days), str(cert_life),
              str(uniqueIPs), str(numIPs), ip, as_dist]
    out.write("\t".join(fields) + "\n")
    
def check_hostlist(isValid, public_key, hostlist):
    """Return True when a certificate averages fewer than two host observations per scan date.

    hostlist is the str() of a parsed host list.  Host observations are the
    dotted-quad addresses found by getIPs_v0 plus the integers found by
    getIPs_v1.  Scan dates are the 8-digit runs found by getdate.  Returns
    False when no scan date is present.

    isValid and public_key are kept for interface compatibility; they only
    fed a disabled diagnostic log line.
    """
    # Only the values the decision needs are computed.  The former
    # diagnostic-only values could raise: IPs_v0[0] on an empty list (when the
    # count of 1 came from getIPs_v1), and getDiffDate on a non-date 8-digit run.
    days = getdate(hostlist)
    if not days:
        return False

    num_hosts = len(getIPs_v0(hostlist)) + sum(getIPs_v1(hostlist))
    # Equivalent to ``num_hosts / len(days) < 2`` for non-negative ints, but
    # independent of Python 2 floor division vs. Python 3 true division.
    return num_hosts < 2 * len(days)

def _flush_group(rids_hostlist, group_lines, discriminator,
                 w_same_device, w_same_certificates, w_remainder_certificates):
    """Route one group of consecutive same-property certificate lines to its output.

    A group with more than one record id for which ``discriminator`` returns
    a falsy value is treated as a single device.  Its record ids go to
    ``w_same_device`` (one per line) and its raw lines to
    ``w_same_certificates``.  Every other group, singletons included, goes to
    ``w_remainder_certificates``.

    Returns True for the single-device case.
    """
    if len(rids_hostlist) > 1 and not discriminator(rids_hostlist):
        w_same_device.write("%s\n" % "\n".join(rids_hostlist.keys()))
        w_same_certificates.write("".join(group_lines))
        return True
    w_remainder_certificates.write("".join(group_lines))
    return False


def doTestDevice(sorted_certificate,
        index_prop,
        same_device_meta_path,
        same_device_certificates_path,
        remainder_certificate_path,
        discriminator):
    """Split a property-sorted certificate file into same-device and remainder sets.

    Each line of ``sorted_certificate`` is tab-separated:
      - field 0 is the record id;
      - field ``index_prop`` is the grouping property (or the pair of fields,
        when ``index_prop`` is a two-element sequence);
      - the second-to-last field is the serialized host list;
      - the last field is the validity flag.
    The file must be sorted on the property so that equal values are adjacent.

    Blank lines are skipped.  Lines with too few fields are printed and
    skipped.  Lines rejected by check_hostlist are dropped.  Consecutive
    surviving lines with an equal property form a group.  Groups with more
    than one record id that ``discriminator`` does not flag are written to
    ``same_device_certificates_path``, with their ids written to
    ``same_device_meta_path``.  All other lines go to
    ``remainder_certificate_path``.  Summary counts are printed at the end.
    """
    print('%s, %s' % (time.ctime(), inspect.stack()[0][3]))

    cnt = 0                # lines read
    cert_cnt = 0           # lines parsed into a certificate record
    two_hosts = 0          # records dropped by check_hostlist
    same_device_hosts = 0  # lines written to the same-device file
    diff_device_hosts = 0  # lines written to the remainder file
    num = 0                # groups classified as a single device

    rids_hostlist = {}
    group_lines = []
    prev_prop = None
    t = time.time()

    with open(sorted_certificate, "r") as reader, \
            open(same_device_meta_path, "w") as w_same_device, \
            open(same_device_certificates_path, "w") as w_same_certificates, \
            open(remainder_certificate_path, "w") as w_remainder_certificates:
        writers = (w_same_device, w_same_certificates, w_remainder_certificates)

        for line in reader:
            cnt += 1
            if cnt % 100000 == 0:
                print('%s %s' % (cnt, time.time() - t))
                t = time.time()

            if not line.strip():
                continue

            fields = line.split("\t")
            try:
                rid = fields[0]
                if isinstance(index_prop, int):
                    prop = fields[index_prop]
                else:
                    prop = (fields[index_prop[0]], fields[index_prop[1]])
                host_list = strToList(fields[-2])
                isValid = fields[-1]
            except IndexError:
                print(line)
                continue
            cert_cnt += 1

            if not check_hostlist(isValid, prop, str(host_list)):
                two_hosts += 1
                continue

            # A new property value closes the previous group.
            if group_lines and prop != prev_prop:
                if _flush_group(rids_hostlist, group_lines, discriminator, *writers):
                    same_device_hosts += len(group_lines)
                    num += 1
                else:
                    diff_device_hosts += len(group_lines)
                rids_hostlist = {}
                group_lines = []

            group_lines.append(line)
            rids_hostlist[rid] = host_list
            prev_prop = prop

        # The last group has no following line to trigger its flush.  It uses
        # the same rule as every other group (a singleton is not a
        # same-device group).
        if group_lines:
            if _flush_group(rids_hostlist, group_lines, discriminator, *writers):
                same_device_hosts += len(group_lines)
                num += 1
            else:
                diff_device_hosts += len(group_lines)

    print('%s %s %s' % (cert_cnt, two_hosts, cert_cnt - two_hosts))
    print('%s %s %s' % (same_device_hosts, diff_device_hosts, same_device_hosts + diff_device_hosts))
    print('%s same-device groups' % num)

def getIndexProp(f):
    """Map a field name to its 0-based column index in the certificate TSV.

    Returns an int for single-column fields.  For 'issuer-sn' it returns a
    two-element list of indexes (presumably the issuer and serial-number
    columns).  An exact name wins.  Otherwise the first known name, in the
    order listed below, that occurs as a substring of ``f`` is used, so a
    string such as "0.cn.remainder" also resolves.  Returns None when nothing
    matches.
    """
    # (name, 0-based column) pairs; "N - 1" converts a 1-based column number.
    # An ordered tuple keeps the substring fallback deterministic, whereas a
    # plain dict's iteration order is arbitrary on Python 2.
    fields = (
        ('pk', 19 - 1),
        ('cn', 4 - 1),
        ('san', 5 - 1),
        ('dn', 6 - 1),
        ('nb', 9 - 1),
        ('na', 10 - 1),
        ('oids', 15 - 1),
        ('crls', 16 - 1),
        ('ocsp', 17 - 1),
        ('aias', 18 - 1),
        ('issuer-sn', [7 - 1, 8 - 1]),
    )

    for name, column in fields:
        if name == f:
            return column
    for name, column in fields:
        if name in f:
            return column
    return None

def sort_certificate(remainder_path, index_prop):
    """Sort the tab-separated file at remainder_path in place on the given column(s).

    index_prop is a 0-based column index, or a sequence of indexes (such as
    the pair returned for 'issuer-sn') used as successive sort keys.  The
    file is replaced only after ``sort`` exits successfully.  On failure the
    original file is left untouched and the error (e.g.
    subprocess.CalledProcessError or OSError) propagates.
    """
    import shutil
    import subprocess
    import tempfile

    columns = [index_prop] if isinstance(index_prop, int) else list(index_prop)

    # An argument list rather than a shell string: the path needs no quoting,
    # and a literal tab is passed instead of bash's $'\t' (os.system runs
    # /bin/sh, which may not understand it).
    cmd = ["sort", "-t", "\t"]
    cmd.extend("-k%d,%d" % (col + 1, col + 1) for col in columns)
    cmd.append(remainder_path)
    print(cmd)

    # A uniquely named temp file next to the target: no clash with other runs
    # (unlike a fixed /tmp name) and the final rename stays on one filesystem.
    target_dir = os.path.dirname(os.path.abspath(remainder_path))
    fd, tmp_path = tempfile.mkstemp(suffix=".sort", dir=target_dir)
    replaced = False
    try:
        with os.fdopen(fd, "w") as out:
            subprocess.check_call(cmd, stdout=out)
        shutil.copymode(remainder_path, tmp_path)  # mkstemp creates it 0600
        print("mv %s %s" % (tmp_path, remainder_path))
        os.rename(tmp_path, remainder_path)
        replaced = True
    finally:
        if not replaced and os.path.exists(tmp_path):
            os.remove(tmp_path)
    
if __name__ == "__main__":
    # Cascade order: CN -> PK -> OCSP -> CRLs -> AIAs -> SANs -> OIDs.
    # Each stage links the previous stage's remainder on the next field.
    # Optional overrides: argv[1] = initial sorted input file,
    #                     argv[2:] = field names in cascade order.
    order = ['cn', 'pk', 'ocsp', 'crls', 'aias', 'san', 'oids']
    sorted_certificate = "/home/tjchung/research/certs/invalid/data/cascade/0.cn.remainder.tsv"
    if len(sys.argv) > 1:
        sorted_certificate = sys.argv[1]
    if len(sys.argv) > 2:
        order = sys.argv[2:]

    unknown = [name for name in order if getIndexProp(name) is None]
    if unknown:
        sys.exit("unknown field name(s): %s" % ", ".join(unknown))

    # True from overlapTime means "different devices".
    discriminator = overlapTime

    if not os.path.isdir("cascade"):
        os.makedirs("cascade")

    for i, field in enumerate(order):
        index_prop = getIndexProp(field)
        print('%s %s' % (sorted_certificate, index_prop))
        same_device_path = "cascade/%s.%s.tsv" % (i, field)
        same_device_meta_path = "cascade/%s.%s.meta.tsv" % (i, field)
        remainder_path = "cascade/%s.%s.remainder.tsv" % (i, field)

        print(same_device_path)
        print(same_device_meta_path)
        print(remainder_path)

        doTestDevice(sorted_certificate,
                index_prop,
                same_device_meta_path,
                same_device_path,
                remainder_path,
                discriminator)

        if i == len(order) - 1:  # the last remainder needs no further sorting
            break

        # The next stage groups on the next field, so sort the remainder on it.
        sort_certificate(remainder_path, getIndexProp(order[i + 1]))
        sorted_certificate = remainder_path
    
    
