"""
try:
    import psyco
    psyco.full()
except ImportError:
    print 'Psyco not installed, the program will just run slower'
"""

from fcgi import WSGIServer
import re

import MySQLdb
import MySQLdb.cursors

#for stop-watch
import time

# Module-level connection name. Nothing in this file stores a live
# connection here; connections are kept in CONNECTION_POOL by getDbCursor().
con = None

# Action codes returned by analizePathInfo() and dispatched on in
# ip2locationHandler(); they are only ever compared for equality.
ACTION_GET_COUNTRY = '0'
ACTION_GET_ALL_DATA = '1'
ACTION_IS_NEWYORKER = '2'
ACTION_SHOW_NEWYORKER_CACHE = '3'
ACTION_CLEAR_NEWYORKER_CACHE= '4'
ACTION_IS_AMERICAN = 'ACTION_IS_AMERICAN'

#to keep the last timestamp when the stopWatch function was invoked
# (a time.time() value; 0 until the first call, so the first delta is
# measured from the epoch)
STOPWATCH_TIMESTAMP=0

def stopWatch():
    """
    Return the seconds elapsed since the previous stopWatch() call and
    restart the watch.

    The first call measures from STOPWATCH_TIMESTAMP's initial value (0),
    which is why callers invoke it twice in a row to "reset" it.
    """
    global STOPWATCH_TIMESTAMP
    # Read the clock once so the restart point is exactly the instant the
    # elapsed time was measured at; two time.time() calls dropped the gap
    # between them from every following measurement.
    now = time.time()
    elapsed = float(now - STOPWATCH_TIMESTAMP)
    STOPWATCH_TIMESTAMP = now
    return elapsed

def getRangesForCity(regionName,cityName):
    '''Given a region and city name, return a list of 2 element tuples
       (ip_from, ip_to), one per ipcountryregioncityisp row whose region
       and city match exactly.'''
    # The names are passed as query parameters so the driver quotes them.
    # Interpolating them with '%' broke on names containing an apostrophe
    # (e.g. "Coeur d'Alene") and allowed SQL injection through the names.
    sql = "select ip_from, ip_to from ipcountryregioncityisp where region = %s and city = %s"

    cursor = getDbCursor()
    cursor.execute(sql, (regionName, cityName))

    return [(row['ip_from'], row['ip_to']) for row in cursor.fetchall()]

# Size threshold for COUNTRY_DATA_CACHE; getCountryData() starts evicting
# entries once the cache holds more than this many.
MAX_COUNTRY_DATA_CACHE_SIZE=500
# ip string -> ipcountry row (or None when no row matched), filled by
# getCountryData().
COUNTRY_DATA_CACHE = {}

def getCountryData(ip):
    """
    Return the ipcountry row whose [ip_from, ip_to] range contains `ip`.

    Every answer, including a None miss, is memoised in COUNTRY_DATA_CACHE.
    When the cache already holds more than MAX_COUNTRY_DATA_CACHE_SIZE
    entries, the first three keys in dict iteration order are evicted
    before the new answer is stored.
    """
    if ip in COUNTRY_DATA_CACHE:
        return COUNTRY_DATA_CACHE[ip]

    numeric = ip2num(ip)
    query = "select * from ipcountry where ip_from <= %d and ip_to >= %d limit 1;" % (numeric, numeric)
    cursor = getDbCursor()
    cursor.execute(query)
    row = cursor.fetchone()

    if len(COUNTRY_DATA_CACHE) > MAX_COUNTRY_DATA_CACHE_SIZE:
        for _ in range(3):
            del COUNTRY_DATA_CACHE[next(iter(COUNTRY_DATA_CACHE))]

    COUNTRY_DATA_CACHE[ip] = row
    return row

def getCountryRegionCityIspData(ip):
    """
    Fetch the first ipcountryregioncityisp row whose range contains `ip`.

    Returns whatever the cursor's fetchone() yields: a row, or None when
    no range matches.
    """
    query = ("select * from ipcountryregioncityisp "
             "where ip_from <= %d and ip_to >= %d limit 1;")
    numeric = ip2num(ip)
    cursor = getDbCursor()
    cursor.execute(query % (numeric, numeric))
    return cursor.fetchone()

def ip2num(ipString):
    """
    Convert a dotted-quad IPv4 string ("1.2.3.4") to its 32-bit integer value.

    Whitespace around each octet is ignored.

    Raises:
        Exception: "Invalid IP" when ipString is None.
        ValueError: when the string does not have exactly four decimal
            octets in 0..255. Previously too few octets raised IndexError,
            extra octets were silently ignored, and out-of-range octets
            produced a number that belongs to a different address.
    """
    if ipString is None:
        raise Exception("Invalid IP")

    octets = [octet.strip() for octet in ipString.split('.')]
    if len(octets) != 4:
        raise ValueError("Invalid IP %r: expected 4 octets" % (ipString,))

    num = 0
    for octet in octets:
        # isdigit() rejects '', signs and other forms int() would accept.
        if not octet.isdigit():
            raise ValueError("Invalid IP %r: bad octet %r" % (ipString, octet))
        value = int(octet)
        if value > 255:
            raise ValueError("Invalid IP %r: octet %d out of range" % (ipString, value))
        num = (num << 8) | value
    return num

def num2ip(numericIp):
    """
    Convert a numeric IPv4 address back to dotted-quad text.

    The lower three octets are masked to 0..255; the top one is not, so
    values above 2**32 - 1 yield a first component larger than 255.
    """
    if numericIp is None:
        raise Exception("Invalid numeric IP. Must be an integer")
    octets = [numericIp >> 24]
    octets.extend((numericIp >> shift) & 255 for shift in (16, 8, 0))
    return '.'.join(str(octet) for octet in octets)

MAX_CONNECTIONS = 40
CONNECTION_POOL = []
CURRENT_POOL_INDEX = 0
def getDbCursor():
    '''
    Return a DictCursor taken from a round-robin pool of MySQL connections.

    Until CONNECTION_POOL holds MAX_CONNECTIONS connections, each call opens
    a new one and appends it; after that, calls cycle through the pooled
    connections in order. Every cursor runs "SET AUTOCOMMIT=1" first, so
    statements issued through it are autocommitted.
    '''
    # Original note: working from a single connection segfaults under high
    # load; MAX_CONNECTIONS=5 was enough, the pool is doubled just in case.
    global CURRENT_POOL_INDEX

    # "<" (not "<="): the pool holds exactly MAX_CONNECTIONS connections.
    # The "not CONNECTION_POOL" guard always creates at least one, so a
    # MAX_CONNECTIONS of 0 cannot lead to a modulo-by-zero below.
    if not CONNECTION_POOL or len(CONNECTION_POOL) < MAX_CONNECTIONS:
        # NOTE(review): credentials are hard-coded; consider moving them to config.
        con = MySQLdb.connect(host='localhost',
                              user='ip2location',
                              passwd='ip2location',
                              db='ip2location',
                              use_unicode=True,
                              cursorclass=MySQLdb.cursors.DictCursor,
                              )
        CONNECTION_POOL.append(con)
        CURRENT_POOL_INDEX = len(CONNECTION_POOL)-1
    else:
        CURRENT_POOL_INDEX = (CURRENT_POOL_INDEX + 1) % len(CONNECTION_POOL)
        con = CONNECTION_POOL[CURRENT_POOL_INDEX]

    cursor = con.cursor(MySQLdb.cursors.DictCursor)
    cursor.execute("SET AUTOCOMMIT=1;")

    return cursor

def loadIpCountryData():
    print "loadIpCountryData"
    #TODO: In the future, next time we run this, instead of executing each
    #instruction one at the time, will create an temp.sql and copy there first
    #then execute a system call to mysql and pass the SQL file instead, that
    #is probably faster than doing it from python

    global con #global reference to the connection
    sourceDataFile = "IPCountry.csv"
    cursor=getDbCursor()

    try:
        cursor.execute("DROP TABLE ipcountry;")
    except Exception, e:
   	    print "Nothing to drop"

    #DATA LOOKS LIKE:
    #"50331648","50331903","US","UNITED STATES"
    cursor.execute("""CREATE TABLE ipcountry (
                          ip_from INT(10) UNSIGNED ZEROFILL NOT NULL, 
                          ip_to   INT(10) UNSIGNED ZEROFILL NOT NULL, 
                          country_short CHAR(2), 
                          country_name VARCHAR(64)
                          ) ENGINE=INNODB;""")
    #cursor.close()

    #GO THROUGH FILE AND INSERT
    f = open(sourceDataFile,"rb")
    errored=False
    sql = ""

    for line in f:
        try:
            d = line.strip().replace('"','').split(",")
            sql = "INSERT INTO ipcountry VALUES ("+d[0]+","+d[1]+",'"+MySQLdb.escape_string(d[2])+"','"+MySQLdb.escape_string(d[3])+"');\n"
            #print sql
            cursor.execute(sql)
        except Exception,e:
            print "PROBLEM WITH",sql,e
            errored=True

    f.close()

    if errored:
        return

    #create indexes
    sql = "CREATE INDEX idx_ic_a ON ipcountry (ip_from);"
    cursor.execute(sql)

    sql = "CREATE INDEX idx_ic_b ON ipcountry (ip_to);"
    cursor.execute(sql)

    sql = "CREATE INDEX idx_ic_c ON ipcountry (ip_from, ip_to);"
    cursor.execute(sql)
   

def loadIpCountryRegionCityIspData():
    sourceDataFile = "IP-COUNTRY-REGION-CITY-ISP.CSV"
    global con #global reference to the connection
    cursor = getDbCursor()

    try:
        cursor.execute("DROP TABLE ipcountryregioncityisp;")
        print "Dropped the table"
    except Exception, e:
 	    print "Nothing to drop"

    #DATA LOOKS LIKE:
    cursor.execute("""CREATE TABLE ipcountryregioncityisp (
                          ip_from INT(10) UNSIGNED ZEROFILL NOT NULL, 
                          ip_to   INT(10) UNSIGNED ZEROFILL NOT NULL, 
                          country_short CHAR(2), 
                          country_name VARCHAR(64), 
                          region VARCHAR(128), 
                          city VARCHAR(128), 
                          isp_name VARCHAR(256)
                          ) ENGINE=INNODB;""")

    f = open(sourceDataFile,"rb")
    i=0
    line = f.readline()
    errored=False
    
    while line not in (None,""):
        d = line.strip().replace('"','').split(",")

        sql = u"INSERT INTO ipcountryregioncityisp VALUES (" + unicode(d[0],'latin') + u","+unicode(d[1],'latin')+u",'"+unicode(MySQLdb.escape_string(d[2]),'latin')+u"','"+unicode(MySQLdb.escape_string(d[3]),'latin')+u"','"+unicode(MySQLdb.escape_string(d[4]),'latin')+u"','"+unicode(MySQLdb.escape_string(d[5]),'latin')+u"','"+unicode(MySQLdb.escape_string(d[6]),'latin')+u"');"
        try:
            cursor.execute(sql)
            i+=1
            if i%100 == 0:
                print ".",
            if i%1000 == 0:
                print ""
                i=0
                
        except Exception,e:
            print "PROBLEM WITH","\n",e
            errored = True

        if errored:
            f.close()
            return

        line = f.readline()

    f.close()

    #create indexes
    sql = "CREATE INDEX idx_icrci_a ON ipcountryregioncityisp (ip_from);"
    cursor.execute(sql)

    sql = "CREATE INDEX idx_icrci_b ON ipcountryregioncityisp (ip_to);"
    cursor.execute(sql)

    sql = "CREATE INDEX idx_icrci_c ON ipcountryregioncityisp (ip_from, ip_to);"
    cursor.execute(sql)

    cursor.close()
    con.close()
    con = None

# Upper bound on CACHED_NEW_YORKERS, enforced by
# tryPoppingUnpopularCachedNewYorker().
MAX_CACHED_NEW_YORKERS = 1000
# ip string -> number of hits, for IPs isNewYorker() has found in range.
CACHED_NEW_YORKERS = {} #IP:HITS
# Sorted list of (ip_from, ip_to) numeric ranges, built by loadNewYorkers().
NEW_YORKERS = []
# Number of cache hits served by isNewYorker(); shown by showNewYorkerCache().
NEW_YORKERS_HIT_COUNT = 0

# Same structures for the US ranges: AMERICANS is filled from
# us_based_ip_ranges.dat by loadAmericans(), CACHED_AMERICANS and
# AMERICANS_HIT_COUNT are maintained by isAmerican().
MAX_CACHED_AMERICANS = 1000
AMERICANS = []
CACHED_AMERICANS = {} #IP:HITS
AMERICANS_HIT_COUNT = 0

def loadNewYorkers():
    global NEW_YORKERS

    #call it twice to reset delta to 0 or as close as possible
    stopWatch()
    stopWatch()
    
    print "Loading New York IP Ranges"
    from newyorkData import newYorkIpRanges, brooklynIpRanges,\
         bronxIpRanges, queensIpRanges, longIslandCityIpRanges,\
         coronaIpRanges, statenIslandIpRanges,  forestHillsIpRanges,\
         sunnysideIpRanges, woodsideIpRanges, jacksonHeightsIpRanges,\
         jamaicaIpRanges, astoriaIpRanges, elmhurstIpRanges,\
         ridgewoodIpRanges, hobokenIpRanges, jerseyCityIpRanges,\
         unionCityIpRanges, westNewYorkIpRanges, edgewaterIpRanges
    
    print "(%f seconds)" % stopWatch()
    print "Appending Ranges..."
    print "(%f seconds)" % stopWatch()    

    NEW_YORKERS = newYorkIpRanges + brooklynIpRanges + bronxIpRanges +\
                  queensIpRanges + longIslandCityIpRanges +\
                  coronaIpRanges + statenIslandIpRanges +\
                  forestHillsIpRanges + sunnysideIpRanges +\
                  woodsideIpRanges + jacksonHeightsIpRanges +\
                  jamaicaIpRanges + astoriaIpRanges + elmhurstIpRanges +\
                  ridgewoodIpRanges + hobokenIpRanges + jerseyCityIpRanges +\
                  unionCityIpRanges + westNewYorkIpRanges + edgewaterIpRanges

                  
    print "Clearing memory from initial imports"

    del(newYorkIpRanges)
    del(brooklynIpRanges)
    del(bronxIpRanges)
    del(queensIpRanges)
    del(longIslandCityIpRanges)
    del(coronaIpRanges)
    del(statenIslandIpRanges)
    del(forestHillsIpRanges)
    del(sunnysideIpRanges)
    del(woodsideIpRanges)
    del(jacksonHeightsIpRanges)
    del(jamaicaIpRanges)
    del(astoriaIpRanges)
    del(elmhurstIpRanges)
    del(ridgewoodIpRanges)
    del(hobokenIpRanges)
    del(jerseyCityIpRanges)
    del(unionCityIpRanges)
    del(westNewYorkIpRanges)
    del(edgewaterIpRanges)
    
    print "(%f seconds)" % stopWatch()
    print "Sorting New York IP Ranges"
    NEW_YORKERS.sort()
    print "(%f seconds)" % stopWatch()

def loadAmericans():
    global AMERICANS
    #reset stop watch
    stopWatch()
    stopWatch()

    print "Loading American IP ranges..."

    f = open('us_based_ip_ranges.dat','r')
    f.seek(0,2)
    end = f.tell()
    f.seek(0,0)

    while (f.tell() < end):
        line = f.readline()

        if line is None or line is '':
            f.close()
            print "--"
            break

        ips = line.strip().split(':')
        if len(ips)==2:
            AMERICANS.append((long(ips[0]),long(ips[1])))

    print "Loaded %d american ip ranged (%f seconds)" % (len(AMERICANS),stopWatch())
    print "Sorting..."
    AMERICANS.sort()
    print "(%f seconds)" % stopWatch()

def isAmerican(ip):
    '''
    Checks if the given ip belongs to new york.
    All results are cached on a limited sized cache where
    ips with less hits are removed when the list is filled.
    '''
    global CACHED_AMERICANS
    global AMERICANS
    global AMERICANS_HIT_COUNT
    
    if CACHED_AMERICANS.has_key(ip):
        #Add another hit
        CACHED_AMERICANS[ip] += 1
        AMERICANS_HIT_COUNT += 1
        return True

    result = binarySearch(ip,AMERICANS)#,True) #Verbosity

    if result is not None:
        CACHED_AMERICANS[ip]=1
        print "Added an american to the cache",AMERICANS_HIT_COUNT
        tryPoppingUnpopularCachedAmerican()

    return result is not None
    

def sortDictByValue(d):
    """
    Return a copy of `d` whose iteration order is ascending by value.

    Entries are ordered by (value, key), so equal values fall back to key
    order. The input dict is not modified.

    The result is a collections.OrderedDict. It is still a dict subclass, so
    callers can index, pop and iterate it as before. Unlike a plain dict on
    Python 2, it keeps the sorted order when iterated.
    """
    import collections

    ordered = sorted(d.items(), key=lambda item: (item[1], item[0]))
    return collections.OrderedDict(ordered)

def tryPoppingUnpopularCachedNewYorker():
    """
    Keep the New York IP cache bounded so it cannot take up all the memory.

    When CACHED_NEW_YORKERS holds more than MAX_CACHED_NEW_YORKERS entries,
    the entry with the fewest hits is removed (ties broken arbitrarily).
    The dict is modified in place.
    """
    if len(CACHED_NEW_YORKERS) > MAX_CACHED_NEW_YORKERS:
        print("About to pop an unpopular new yorker out of the cache")
        # min() finds the least popular entry directly. Taking the "first"
        # key of a re-sorted plain dict picked an arbitrary entry on
        # Python 2, where dicts do not keep insertion order.
        least_popular = min(CACHED_NEW_YORKERS, key=CACHED_NEW_YORKERS.get)
        del CACHED_NEW_YORKERS[least_popular]

def tryPoppingUnpopularCachedAmerican():
    """
    Keep the American IP cache bounded so it cannot take up all the memory.

    When CACHED_AMERICANS holds more than MAX_CACHED_AMERICANS entries,
    the entry with the fewest hits is removed (ties broken arbitrarily).
    The dict is modified in place.
    """
    if len(CACHED_AMERICANS) > MAX_CACHED_AMERICANS:
        print("About to pop an unpopular american out of the cache")
        # min() finds the least popular entry directly. Taking the "first"
        # key of a re-sorted plain dict picked an arbitrary entry on
        # Python 2, where dicts do not keep insertion order.
        least_popular = min(CACHED_AMERICANS, key=CACHED_AMERICANS.get)
        del CACHED_AMERICANS[least_popular]

    
def showNewYorkerCache():
    """
    Returns a string that shows the status of the CACHED_NEW_YORKERS
    """
    global CACHED_NEW_YORKERS
    #tryPoppingUnpopularCachedNewYorker()
    cache = None
    
    if len(CACHED_NEW_YORKERS)>1:
        cache = sortDictByValue(CACHED_NEW_YORKERS)
    else:
        return str(len(CACHED_NEW_YORKERS)) + " ips cached\n" + str(CACHED_NEW_YORKERS)

    print "Executing showNewYorkerCache()", cache

    result = "Total New Yorkers Found %d - %d ips in the cache\n" % (NEW_YORKERS_HIT_COUNT,len(CACHED_NEW_YORKERS))
    for k in cache:
        result += k + ' : ' + str(cache[k]) + '\n'

    return result

def clearNewYorkerCache():
    """
    Clears the CACHED_NEW_YORKERS variable
    """
    global CACHED_NEW_YORKERS

    if len(CACHED_NEW_YORKERS) == 0:
        return str({})
    
    keys = CACHED_NEW_YORKERS.keys()
    print "Got the keys", keys
    for k in keys:
        print "Popping ",k
        CACHED_NEW_YORKERS.pop(k)
        print "Popped"

    print "About to return", str(CACHED_NEW_YORKERS)

    return str(CACHED_NEW_YORKERS)

def binarySearch(ip,LIST,verbose=False):
    '''
    Performs binary search of the Ip on the
    LIST list.
    '''
    stopWatch()
    stopWatch()

    start=0
    end=len(LIST)-1
    middle=(start+end)/2
    numeric_ip = ip2num(ip)

    TO_THE_LEFT=0
    IN_RANGE=1
    TO_THE_RIGHT=2

    def ipInRange(ip,ipRange):
        #compare with the borders first
        if ip < ipRange[0]:
            return TO_THE_LEFT

        if ip > ipRange[1]:
            return TO_THE_RIGHT

        if ip >= ipRange[0] and ip <= ipRange[1]:
            return IN_RANGE
        

        return ip >= ipRange[0] and ip <= ipRange[1]

    result = None
    iterations = 1    
    while start < end:
        iterations += 1

        if ipInRange(numeric_ip,LIST[start]) == IN_RANGE:
           result = LIST[start]
           break
        elif ipInRange(numeric_ip,LIST[end]) == IN_RANGE:
           result = LIST[end]
           break
        elif ipInRange(numeric_ip,LIST[middle]) == IN_RANGE:
           result = LIST[middle]
           break

        #We gotta move the pointers depending
        #on a comparison towards the middle

        whereToMove = ipInRange(numeric_ip,LIST[middle])

        if whereToMove == TO_THE_RIGHT:
           start=middle+1

           if verbose:
               print "\t>>>",

        elif whereToMove == TO_THE_LEFT:
           end=middle-1

           if verbose:
               print "\t<<<",

        elif whereToMove == IN_RANGE:
            result = LIST[middle]
            break
           
        middle = start + (end-start)/2

        if verbose:
            print " (%s:%s:%s)" % (start, middle, end)

    print "Binary search done (%f seconds)" % stopWatch()
    return result

def isNewYorker(ip):
    '''
    Checks if the given ip belongs to new york.
    The New York IP ranges are loaded lazily.
    All results are cached on a limited sized cache where
    ips with less hits are removed when the list is filled.
    '''
    global CACHED_NEW_YORKERS
    global NEW_YORKERS
    global NEW_YORKERS_HIT_COUNT
    
    if CACHED_NEW_YORKERS.has_key(ip):
        #Add another hit
        CACHED_NEW_YORKERS[ip] += 1
        NEW_YORKERS_HIT_COUNT += 1
        return True

    result = binarySearch(ip,NEW_YORKERS)#,True) #Verbosity


    if result is not None:
        CACHED_NEW_YORKERS[ip]=1
        print "Added a New Yorker to the cache",NEW_YORKERS_HIT_COUNT
        tryPoppingUnpopularCachedNewYorker()

    return result is not None


#Uncomment these lines if you want to re-populate the database from the .CSV files
#(they run at module level, i.e. on every start of the server)
#loadIpCountryData()
#loadIpCountryRegionCityIspData()
# Number of requests handled by ip2locationHandler() since start-up.
counter = 0

def analizePathInfo(env):
    '''
    Turn env['PATH_INFO'] into the ip to look up and the action to run.

    Accepted paths:
      /                  -> ACTION_GET_COUNTRY for REMOTE_ADDR
      /<ip>              -> ACTION_GET_COUNTRY for <ip>
      /<word>[/][<ip>]   -> <word> picks the action: all, isnewyorker,
                            isamerican, showcache or clearcache. Any other
                            word keeps ACTION_GET_COUNTRY. <ip> defaults to
                            REMOTE_ADDR.

    Returns {"ip": <ip string>, "action": <one of the ACTION_* constants>}.
    '''
    assert 'PATH_INFO' in env
    path_info = env['PATH_INFO']
    parsed = {'ip': env['REMOTE_ADDR'], 'action': ACTION_GET_COUNTRY}

    #No parameters - Get country for remote ip
    if path_info == '/':
        return parsed

    #No action, but an ip is passed
    bare_ip = re.match(r'^/(\d*\.\d*\.\d*\.\d*)', path_info)
    if bare_ip is not None:
        parsed['ip'] = bare_ip.group(1)
        return parsed

    #An action and maybe an IP is passed
    action_match = re.match(r'^/([a-zA-Z]*)(/)?(\d*\.\d*\.\d*\.\d*)?', path_info)
    if action_match is None:
        return parsed

    keyword, _slash, explicit_ip = action_match.groups()
    actions_by_keyword = {
        'all': ACTION_GET_ALL_DATA,
        'isnewyorker': ACTION_IS_NEWYORKER,
        'isamerican': ACTION_IS_AMERICAN,
        'showcache': ACTION_SHOW_NEWYORKER_CACHE,
        'clearcache': ACTION_CLEAR_NEWYORKER_CACHE,
    }
    if keyword in actions_by_keyword:
        parsed['action'] = actions_by_keyword[keyword]

    if explicit_ip is not None:
        parsed['ip'] = explicit_ip

    return parsed

def ip2locationHandler(env, response):
    response('200 OK',[('Content-Type','text/plain'),
                               ('Connection','close')])
    global counter
    counter += 1
    #print "Counter",counter
    actionDict = analizePathInfo(env)

    assert(actionDict.has_key('ip'))
    ip = actionDict['ip']

    assert(actionDict.has_key('action'))
    action = actionDict['action']

    result = ""
    data = None

    global ACTION_IS_AMERICAN

    #/all
    if action == ACTION_GET_ALL_DATA:
        print "/all"
        data = getCountryRegionCityIspData(ip)
    #/isnewyorker
    elif action == ACTION_IS_NEWYORKER:
        #TODO: Add stopWatch value on a list
        #to calculate avg lookup time
        print "/isnewyorker"
        stopWatch()
        stopWatch()
        if isNewYorker(ip):
            queryTime = stopWatch()
            return '1 - %f' % (queryTime)
        queryTime = stopWatch()
        return '0 - %f' % (queryTime)
    elif action == ACTION_IS_AMERICAN:
        print "/isamerican"
        stopWatch()
        stopWatch()
        if isAmerican(ip):
            queryTime = stopWatch()
            return '1 - %f' % (queryTime)
        queryTime = stopWatch()
        return '0 - %f' % (queryTime)
    #/showcache
    elif action == ACTION_SHOW_NEWYORKER_CACHE:
        print "/showcache"
        return showNewYorkerCache()
    #/clearcache
    elif action == ACTION_CLEAR_NEWYORKER_CACHE:
        print "/clearcache"
        return clearNewYorkerCache()
    #/
    elif action == ACTION_GET_COUNTRY:
        data = getCountryData(ip)


    for x in data:
        if x not in ('ip_to','ip_from'):
            result += x + "=" + str(data[x]) + "|"

    result = result[:-1]

    #return [str(env).replace(',',',\n')]
    return [str(result)]

if __name__ == "__main__":
    # Load the in-memory New York and US range tables first, then serve
    # ip2locationHandler over FastCGI on 127.0.0.1:9999.
    loadNewYorkers()
    loadAmericans()
    WSGIServer(ip2locationHandler, bindAddress = ('127.0.0.1',9999)).run()
