from sympy import *
from sympy.abc import q,z
from math import log
from copy import deepcopy

def naive_det(a):
    """
    Determinant via cofactor expansion along the first row.

    The input is first converted to an ImmutableMatrix.  Matrices with at
    most three rows are passed straight to det; for larger ones each entry
    a[0,col] is multiplied, with alternating sign, by the recursively
    computed determinant of the matrix with row 0 and column col removed.
    The sum is returned after expand.
    """
    a = ImmutableMatrix(a)
    size = a.shape[0]
    if size <= 3:
        return det(a)
    cofactor_sum = sum(
        (-1)**col*a[0,col]*naive_det(a.row_del(0).col_del(col))
        for col in range(size)
    )
    return expand(cofactor_sum)

def truncate_matrix(a,order):
    """
    Drop all terms of order higher than `order` in q from every entry.

    Only a.shape[0] is read, and the result is a fresh size x size matrix,
    so a is treated as square.
    """
    size = a.shape[0]
    cutoff = O(q**(order+1))
    b = zeros(size)
    for row in range(size):
        for col in range(size):
            # Adding the O-term absorbs the higher powers; removeO strips it.
            entry = a[row,col]+cutoff
            b[row,col] = entry.removeO()
    return b

def p_adic_valuation_int(i):
    """
    Number of times the module-level prime p divides the integer i.

    Returns float('inf') for i == 0.
    """
    if i == 0:
        return float('inf')
    exponent = 0
    quotient, remainder = divmod(i, p)
    while remainder == 0:
        # p divides the current value: count it and continue with i/p.
        exponent += 1
        quotient, remainder = divmod(quotient, p)
    return exponent

def p_adic_valuation(r):
    """
    p-adic valuation of r: the valuation of its numerator minus that of
    its denominator.  Returns float('inf') for r == 0.
    """
    if r == 0:
        return float('inf')
    top, bottom = fraction(r)
    return p_adic_valuation_int(top)-p_adic_valuation_int(bottom)

def matrix_p_adic_valuation(a):
    """
    Smallest p-adic valuation among all entries of the matrix a.

    Returns float('inf') when a has no entries or every entry has
    valuation float('inf').
    """
    rows, cols = a.shape[0], a.shape[1]
    entry_valuations = (
        p_adic_valuation(a[row,col])
        for row in range(rows)
        for col in range(cols)
    )
    return min(entry_valuations, default=float('inf'))

def list_of_matrix_p_adic_valuation(l):
    """
    Apply matrix_p_adic_valuation to every matrix in l, in order.

    Prints a progress line: a header, then one dot before every tenth
    matrix is processed, then a newline.
    """
    print("computing list of p-adic valuations of",len(l),"terms", end="")
    valuations = []
    for count, matrix in enumerate(l, start=1):
        if count % 10 == 0:
            print(".", end="")
        valuations.append(matrix_p_adic_valuation(matrix))
    print()
    return valuations

def truncate(f,o):
    """
    Drop all terms of f of order higher than o in z.
    """
    remainder = O(z**(o+1))
    truncated = f+remainder
    return truncated.removeO()

def coefficients(f,o):
    """
    List of the coefficients of z^0, z^1, ..., z^o in f.
    """
    return [f.coeff(z,power) for power in range(o+1)]

def p_adic_valuation_at_pi(f):
    """
    f is a polynomial in q, with rational coefficients
    we set q = pi, and return the valuation of that p-adic number
    as a rational number

    For each residue i modulo p-1, the coefficients of q^(i+k(p-1)) are
    combined into one rational number with weights (-p)^k; a nonzero
    combination contributes its p-adic valuation plus i/(p-1).  The
    smallest contribution is returned, or float("inf") if f == 0 or all
    combinations vanish.
    NOTE(review): the weights look like the relation pi^(p-1) = -p; this
    is not stated anywhere in the file - confirm.
    """
    if f == 0:
        return float("inf")
    step = p-1
    top = degree(f)
    best = float("inf")
    for residue in range(step):
        combined = sum(
            f.coeff(q,j)*(-p)**((j-residue)//step)
            for j in range(residue,top+1,step)
        )
        if combined == 0:
            continue
        candidate = p_adic_valuation(combined)+Rational(residue,step)
        if candidate < best:
            best = candidate
    return best
            
def dwork_exp(o):
    """
    Partial sum 1 + sum_{i=1..o} (z + z^p/p)^i / i!, expanded and
    truncated to order o in z.
    """
    base = z+z**p*Rational(1,p)
    partial_sum = sum(
        (base**i*Rational(1,factorial(i)) for i in range(1,o+1)),
        1,
    )
    return truncate(expand(partial_sum),o)

def dwork_exp_coeff(o):
    """
    computes the Taylor coefficients of exp(x+x^p/p),
    following the recursion idea from Rodriguez-Villegas Experimental Number Theory

    Returns the list [c_0, ..., c_o] with c_0 = 1 and
    c_n = (c_{n-1} + c_{n-p}) / n, where the c_{n-p} term is only
    present for n >= p.
    """
    coeffs = [Rational(1,1)]
    for n in range(1,o+1):
        numerator = coeffs[n-1]
        if n >= p:
            numerator += coeffs[n-p]
        coeffs.append(numerator/n)
    return coeffs

def gamma_derivatives(i,val):
    """
    computes the power series of Gamma(x) following
    Rodriguez Villegas p. 156, or Cohen number theory ch. 11
    exactness is guaranteed up to and including p-adic valuation val

    i is the order of the derivative.  The number m of terms is increased
    until both inequalities in more_terms_needed fail; terms whose
    valuation exceeds val are left out of the returned sum.
    """
    log_p = log(p)

    def more_terms_needed(m):
        # The first test must come first: for m == 1 it is true (i >= 1),
        # so log(m-1) = log(0) in the second test is never evaluated.
        if m < i*p/((p-1)*log_p) + 1:
            return True
        return i/(p-1)+m*(p-1)/p-log(i)/log_p - i*log(m-1)/log_p - (2*p-1)/p <= val

    m = 1
    while more_terms_needed(m):
        m += 1
    print("computing order",i,"Gamma derivative up to p-adic valuation",val, end="")
    print(", using",m,"terms in the Mahler expansion", end="")
    b = dwork_exp_coeff(p*m)
    # Factor common to every term of the sum.
    scale = factorial(i)*Rational(1,-p)**i
    total = 0
    falling = 1
    for k in range(1,m+1):
        if k % 10 == 0:
            print(".", end="")
        # falling = z(z-1)...(z-k+1), kept in expanded form.
        falling = expand(falling*(z-k+1))
        term = (-p)**k*b[p*k]*falling.coeff(z,i)*scale
        if p_adic_valuation(term) <= val:
            total += term
    print()
    return total

def gamma_derivative(val):
    """
    previous idea specialized to only the first derivative

    The number m of terms is increased until both inequalities in
    more_terms_needed fail; terms whose p-adic valuation exceeds val are
    left out of the returned sum.
    """
    log_p = log(p)

    def more_terms_needed(m):
        # The first test must come first: it is true for m == 1, so
        # log(m-1) = log(0) in the second test is never evaluated.
        if m < p/((p-1)*log_p) + 1:
            return True
        return 1/(p-1)+m*(p-1)/p - log(m-1)/log_p - (2*p-1)/p <= val

    m = 1
    while more_terms_needed(m):
        m += 1
    print("computing first Gamma derivative up to p-adic valuation",val, end="")
    print(", using",m,"terms in the Mahler expansion", end="")
    b = dwork_exp_coeff(p*m)
    total = 0
    for k in range(1,m+1):
        if k % 10 == 0:
            print(".", end="")
        term = p**(k-1)*factorial(k-1)*b[k*p]
        if p_adic_valuation(term) <= val:
            total += term
    print()
    return total

def matrix_to_vector(a):
    """
    This converts a matrix into a column vector.

    Entry a[i,j] is stored at position i + rows*j, so the columns of a
    are stacked on top of each other, first column first.
    """
    rows, cols = a.shape[0], a.shape[1]
    v = zeros(rows*cols,1)
    for col in range(cols):
        offset = rows*col
        for row in range(rows):
            v[row+offset,0] = a[row,col]
    return v

def vector_to_matrix(v):
    """
    This converts a column vector of size N^2 to a square matrix of size N

    Position i + N*j of v becomes entry [i,j] of the result.
    """
    # The small offset keeps the truncated square root from landing below N.
    size = int(sqrt(v.shape[0]+0.01))
    a = zeros(size,size)
    for col in range(size):
        offset = size*col
        for row in range(size):
            a[row,col] = v[row+offset,0]
    return a

def rhs_matrix(givena, givenphi, m):
    r"""
    This determines the right hand side of the equation
    \partial_q\Phi + A Phi - q^{p-1} Phi A(-q^p/p)  = 0
    At a given order q^{n-1}, n>=1. The givenphi starts with order 0,
    the givena with order -1 (which is not used here), causing an offset in the index
    Each givena and givenphi is a list of matrices with rational coefficients
    The outcome is a single matrix

    Only givenphi[0], ..., givenphi[m-1] are read.
    """
    size = givena[0].shape[0]
    count_a = len(givena)
    rhs = zeros(size,size)
    for i in range(m):
        gap = m-i
        # Contribution of the product A*Phi.
        if gap < count_a:
            rhs -= givena[gap]*givenphi[i]
        # Contribution of Phi*A(...): only when p divides m-i.
        quotient, remainder = divmod(gap,p)
        if remainder == 0 and quotient < count_a:
            rhs -= givenphi[i]*givena[quotient]*((-Rational(p,1))**(1-quotient))
    return rhs

def lhs_sample_matrix(givena,phi,m):
    """
    Here, we start with a proposal for phi_m (a matrix), and
    compute the left hand of the equation as a matrix

    The value is m*phi + givena[0]*phi - p*phi*givena[0].
    """
    leading = givena[0]
    commutator_part = leading*phi - p*phi*leading
    return phi*m + commutator_part

def lhs_matrix(givena,n):
    """
    This tries out all choices for phi_m, and returns the
    left hand side as a matrix of size N^2 x N^2

    Column i*N+j holds the vectorised image, under lhs_sample_matrix, of
    the matrix with a single 1 at position [j,i].
    """
    size = givena[0].shape[0]
    lhs = zeros(size**2)
    for column in range(size**2):
        i, j = divmod(column,size)
        unit = zeros(size,size)
        unit[j,i] = 1
        image = matrix_to_vector(lhs_sample_matrix(givena,unit,n))
        for k in range(len(image)):
            lhs[k,column] = image[k,0]
    return lhs

def solve_one_step(givena,givenphi,n):
    """
    Solve for the order-n coefficient matrix: invert lhs_matrix(givena,n),
    apply it to the vectorised rhs_matrix, and reshape to a matrix.
    """
    rhs_vector = matrix_to_vector(rhs_matrix(givena,givenphi,n))
    operator = lhs_matrix(givena,n)
    solution = operator.inv()*rhs_vector
    return vector_to_matrix(solution)

def solve_n_steps(givena,phi0,n):
    """
    Starting from phi0, compute the coefficient matrices of orders
    1, ..., n one at a time and return the list [phi0, phi_1, ..., phi_n].

    Prints a header, one dot at every tenth order, then a newline.
    """
    print("computing q-series up to order", n, end="")
    series = [phi0]
    for order in range(1,n+1):
        if order % 10 == 0:
            print(".", end="")
        series.append(solve_one_step(givena,series,order))
    print()
    return series

def form_series(givena):
    """
    Assemble the matrix polynomial sum_i q^i * givena[i].
    """
    size = givena[0].shape[0]
    series = zeros(size)
    for power, coefficient in enumerate(givena):
        series += q**power*coefficient
    return series

def check_frobenius(givena,givenphi,o):
    """
    This checks whether the Frobenius intertwining property is satisfied

    Returns q*dPhi/dq + A*Phi - p*Phi*A(q^p/(-p)), expanded and truncated
    to order o in q, where A and Phi are the series formed from givena and
    givenphi.
    """
    quantum = form_series(givena)
    twisted = quantum.subs(q,(q**p)/(-p))
    frob = form_series(givenphi)
    derivative_term = q*diff(frob,q)
    residual = derivative_term + quantum*frob + (-p)*frob*twisted
    return truncate_matrix(expand(residual),o)

def transform_into_plot(v,val):
    """
    v is a list of valuations. Entries equal to inf or larger than the
    promised accuracy val are dropped; for the rest we return [x, y] with
    x the indices kept and y the sign-reversed valuations.
    """
    kept = [
        (index, -valuation)
        for index, valuation in enumerate(v)
        if valuation != float("inf") and valuation <= val
    ]
    return [[index for index, _ in kept], [negated for _, negated in kept]]

def cp1(o,val):
    """
    Frobenius computation for the CP1 example.

    o is the highest power of q computed.  val is the extra p-adic
    precision: the Gamma derivative is computed up to valuation m+val,
    where -m is the smallest valuation among the series coefficients.

    Returns the list of p_adic_valuation_at_pi applied to the coefficients
    of z^0, z^1, z^2 in det(z*I - frob).
    """
    print("CP1")
    print()
    apole = Matrix([[0,0],[2,0]])
    a0 = zeros(2)
    a1 = Matrix([[0,2],[0,0]])
    givena = [apole,a0,a1]
    print("quantum connection multiplied by q:")
    print()
    pprint(form_series(givena))
    print()
    deg = Matrix([[1,0],[0,Rational(1,p)]])
    alpha = eye(2)*deg
    beta = Matrix([[0,0],[1,0]])*deg
    salpha = solve_n_steps(givena,alpha,o)
    sbeta = solve_n_steps(givena,beta,o)
    malpha = min(list_of_matrix_p_adic_valuation(salpha))
    mbeta = min(list_of_matrix_p_adic_valuation(sbeta))
    m = -min(malpha,mbeta)
    gder = gamma_derivative(m+val)
    # Combine the two series into the Frobenius matrix, order by order.
    frob = zeros(2)
    for i in range(o+1):
        frob += q**i*(salpha[i]+2*gder*sbeta[i])
    frob = expand(frob)
    charpoly = det(z*eye(2) - frob)
    return [p_adic_valuation_at_pi(charpoly.coeff(z,i)) for i in range(3)]
       
def cp2(o,val):
    """
    Frobenius computation for the CP2 example.

    o is the highest power of q computed.  val is the extra p-adic
    precision: the Gamma derivative is computed up to valuation m+val,
    where -m is the smallest valuation among the series coefficients.

    Returns the list of p_adic_valuation_at_pi applied to the coefficients
    of z^0, ..., z^3 in det(z*I - frob).
    """
    print("CP2")
    print()
    apole = Matrix([[0,0,0],[3,0,0],[0,3,0]])
    a0 = zeros(3)
    a1 = zeros(3)
    a2 = Matrix([[0,0,3],[0,0,0],[0,0,0]])
    givena = [apole,a0,a1,a2]
    print("Quantum connection multiplied by q:")
    print()
    pprint(form_series(givena))
    print()
    deg = Matrix([[1,0,0],[0,Rational(1,p),0],[0,0,Rational(1,p**2)]])
    alpha = eye(3)*deg
    beta = Matrix([[0,0,0],[1,0,0],[0,1,0]])*deg
    gamma = Matrix([[0,0,0],[0,0,0],[1,0,0]])*deg
    salpha = solve_n_steps(givena,alpha,o)
    sbeta = solve_n_steps(givena,beta,o)
    sgamma = solve_n_steps(givena,gamma,o)
    malpha = min(list_of_matrix_p_adic_valuation(salpha))
    mbeta = min(list_of_matrix_p_adic_valuation(sbeta))
    mgamma = min(list_of_matrix_p_adic_valuation(sgamma))
    m = -min(malpha,mbeta,mgamma)
    gder = gamma_derivative(m+val)
    # Combine the three series into the Frobenius matrix, order by order.
    frob = zeros(3)
    for i in range(o+1):
        frob += q**i*(salpha[i]+3*gder*sbeta[i]+(gder**2)*Rational(9,2)*sgamma[i])
    frob = expand(frob)
    charpoly = det(z*eye(3) - frob)
    return [p_adic_valuation_at_pi(charpoly.coeff(z,i)) for i in range(4)]

def f1(o,val):
    """
    Frobenius computation for the Hirzebruch surface F1 example.

    o is the highest power of q computed.  val is the extra p-adic
    precision: the Gamma derivative is computed up to valuation m+val,
    where -m is the smallest valuation among the series coefficients.

    Returns the list of p_adic_valuation_at_pi applied to the coefficients
    of z^0, ..., z^4 in the characteristic polynomial of frob (computed
    with naive_det).  Prints one "." line per coefficient.
    """
    print("Hirzebruch surface F1")
    print()
    apole = Matrix([[0,0,0,0],[2,0,0,0],[3,0,0,0],[0,1,2,0]])
    a0 = Matrix([[0,0,0,0],[0,-1,1,0],[0,0,0,0],[0,0,0,0]])
    a1 = Matrix([[0,2,0,0],[0,0,0,0],[0,0,0,2],[0,0,0,0]])
    a2 = Matrix([[0,0,0,3],[0,0,0,0],[0,0,0,0],[0,0,0,0]])
    givena = [apole,a0,a1,a2]
    print("Quantum connection multiplied by q:")
    print()
    pprint(form_series(givena))
    print()
    deg = Matrix([[1,0,0,0],[0,Rational(1,p),0,0],[0,0,Rational(1,p),0],[0,0,0,Rational(1,p**2)]])
    alpha = eye(4)*deg
    beta = apole*deg
    gamma = Matrix([[0,0,0,0],[0,0,0,0],[0,0,0,0],[1,0,0,0]])*deg
    salpha = solve_n_steps(givena,alpha,o)
    sbeta = solve_n_steps(givena,beta,o)
    sgamma = solve_n_steps(givena,gamma,o)
    malpha = min(list_of_matrix_p_adic_valuation(salpha))
    mbeta = min(list_of_matrix_p_adic_valuation(sbeta))
    mgamma = min(list_of_matrix_p_adic_valuation(sgamma))
    m = -min(malpha,mbeta,mgamma)
    gder = gamma_derivative(m+val)
    # Combine the three series into the Frobenius matrix, order by order.
    frob = zeros(4)
    for i in range(o+1):
        frob += q**i*(salpha[i]+gder*sbeta[i]+4*(gder**2)*sgamma[i])
    frob = expand(frob)
    charpoly = naive_det(expand(z*eye(4) - frob))
    l = []
    for i in range(5):
        print(".")
        l.append(p_adic_valuation_at_pi(charpoly.coeff(z,i)))
    return l

def cubic_surface(o,val):
    """
    Frobenius computation for (part of) the quantum connection of the
    cubic surface.

    o is the highest power of q computed.  val is the extra p-adic
    precision: the Gamma derivative is computed up to valuation m+val,
    where -m is the smallest valuation among the series coefficients.

    Returns the list of p_adic_valuation_at_pi applied to the coefficients
    of z^0, ..., z^3 in det(z*I - frob).
    """
    print("(Part of) the quantum connection for the cubic surface")
    print()
    apole = Matrix([[0,0,0],[1,0,0],[0,3,0]])
    a0 = Matrix([[0,0,0],[0,9,0],[0,0,0]])
    a1 = Matrix([[0,108,0],[0,0,36],[0,0,0]])
    a2 = Matrix([[0,0,252],[0,0,0],[0,0,0]])
    givena = [apole,a0,a1,a2]
    print("Quantum connection multiplied by q:")
    print()
    pprint(form_series(givena))
    print()
    deg = Matrix([[1,0,0],[0,Rational(1,p),0],[0,0,Rational(1,p**2)]])
    alpha = eye(3)*deg
    beta = apole*deg
    gamma = Matrix([[0,0,0],[0,0,0],[1,0,0]])*deg
    salpha = solve_n_steps(givena,alpha,o)
    sbeta = solve_n_steps(givena,beta,o)
    sgamma = solve_n_steps(givena,gamma,o)
    malpha = min(list_of_matrix_p_adic_valuation(salpha))
    mbeta = min(list_of_matrix_p_adic_valuation(sbeta))
    mgamma = min(list_of_matrix_p_adic_valuation(sgamma))
    m = -min(malpha,mbeta,mgamma)
    gder = gamma_derivative(m+val)
    # Combine the three series into the Frobenius matrix, order by order.
    frob = zeros(3)
    for i in range(o+1):
        frob += q**i*(salpha[i]+gder*sbeta[i]+(gder**2)*Rational(3,2)*sgamma[i])
    # Expand once after the sum is complete rather than on every iteration.
    frob = expand(frob)
    charpoly = det(z*eye(3) - frob)
    return [p_adic_valuation_at_pi(charpoly.coeff(z,i)) for i in range(4)]

def intersection_of_quadrics(o,val):
    """
    Frobenius computation for the intersection of two quadrics in P^5.

    o is the highest power of q computed.  val is the extra p-adic
    precision: the first and third Gamma derivatives are computed up to
    valuation m+val, where -m is the smallest valuation among the series
    coefficients.

    Returns the list of p_adic_valuation_at_pi applied to the coefficients
    of z^0, ..., z^4 in the characteristic polynomial of frob (computed
    with naive_det).  Prints one "." line per coefficient.
    """
    print("Intersection of two quadrics in P^5")
    print()
    apole = 2*Matrix([[0,0,0,0],[1,0,0,0],[0,4,0,0],[0,0,1,0]])
    a0 = zeros(4)
    a1 = 2*Matrix([[0,4,0,0],[0,0,2,0],[0,0,0,4],[0,0,0,0]])
    a2 = zeros(4)
    a3 = 2*Matrix([[0,0,0,4],[0,0,0,0],[0,0,0,0],[0,0,0,0]])
    givena = [apole,a0,a1,a2,a3]
    print("Quantum connection multiplied by q:")
    print()
    pprint(form_series(givena))
    print()
    deg = Matrix([[1,0,0,0],[0,Rational(1,p),0,0],[0,0,Rational(1,p**2),0],[0,0,0,Rational(1,p**3)]])
    alpha = eye(4)*deg
    beta = Matrix([[0,0,0,0],[1,0,0,0],[0,4,0,0],[0,0,1,0]])*deg
    gamma = Matrix([[0,0,0,0],[0,0,0,0],[1,0,0,0],[0,1,0,0]])*deg
    delta = Matrix([[0,0,0,0],[0,0,0,0],[0,0,0,0],[1,0,0,0]])*deg
    salpha = solve_n_steps(givena,alpha,o)
    sbeta = solve_n_steps(givena,beta,o)
    sgamma = solve_n_steps(givena,gamma,o)
    sdelta = solve_n_steps(givena,delta,o)
    malpha = min(list_of_matrix_p_adic_valuation(salpha))
    mbeta = min(list_of_matrix_p_adic_valuation(sbeta))
    mgamma = min(list_of_matrix_p_adic_valuation(sgamma))
    mdelta = min(list_of_matrix_p_adic_valuation(sdelta))
    m = -min(malpha,mbeta,mgamma,mdelta)
    gder = gamma_derivative(m+val)
    gder3 = gamma_derivatives(3,m+val)
    # Combine the four series into the Frobenius matrix, order by order.
    frob = zeros(4)
    for i in range(o+1):
        frob += q**i*(salpha[i]+2*gder*sbeta[i]+8*(gder**2)*sgamma[i]+(12*(gder**3)-Rational(20,3)*gder3)*sdelta[i])
    frob = expand(frob)
    charpoly = naive_det(expand(z*eye(4) - frob))
    l = []
    for i in range(5):
        print(".")
        l.append(p_adic_valuation_at_pi(charpoly.coeff(z,i)))
    return l

def twistor_space_small(o,val):
    """
    Frobenius computation for the small summand of the quantum connection
    of the twistor space.

    o is the highest power of q computed.  val is the extra p-adic
    precision: the first and third Gamma derivatives are computed up to
    valuation m+val, where -m is the smallest valuation among the series
    coefficients.

    Returns the list of p_adic_valuation_at_pi applied to the coefficients
    of z^0, ..., z^4 in the characteristic polynomial of frob (computed
    with naive_det).  Prints one "." line per coefficient.
    """
    print("Small summand of the quantum connection for the twistor space")
    print()
    apole = Matrix([[0,0,0,0],[1,0,0,0],[0,1,0,0],[0,0,1,0]])
    a0 = zeros(4)
    a1 = Matrix([[0,4,0,0],[0,0,0,0],[0,0,0,4],[0,0,0,0]])
    givena = [apole,a0,a1]
    print("Quantum connection multiplied by q:")
    print()
    pprint(form_series(givena))
    print()
    deg = Matrix([[1,0,0,0],[0,Rational(1,p),0,0],[0,0,Rational(1,p**2),0],[0,0,0,Rational(1,p**3)]])
    alpha = eye(4)*deg
    beta = Matrix([[0,0,0,0],[1,0,0,0],[0,1,0,0],[0,0,1,0]])*deg
    gamma = Matrix([[0,0,0,0],[0,0,0,0],[1,0,0,0],[0,1,0,0]])*deg
    delta = Matrix([[0,0,0,0],[0,0,0,0],[0,0,0,0],[1,0,0,0]])*deg
    salpha = solve_n_steps(givena,alpha,o)
    sbeta = solve_n_steps(givena,beta,o)
    sgamma = solve_n_steps(givena,gamma,o)
    sdelta = solve_n_steps(givena,delta,o)
    malpha = min(list_of_matrix_p_adic_valuation(salpha))
    mbeta = min(list_of_matrix_p_adic_valuation(sbeta))
    mgamma = min(list_of_matrix_p_adic_valuation(sgamma))
    mdelta = min(list_of_matrix_p_adic_valuation(sdelta))
    m = -min(malpha,mbeta,mgamma,mdelta)
    gder = gamma_derivative(m+val)
    gder3 = gamma_derivatives(3,m+val)
    # Combine the four series into the Frobenius matrix, order by order.
    frob = zeros(4)
    for i in range(o+1):
        frob += q**i*(salpha[i]+gder*sbeta[i]+gder**2*Rational(1,2)*sgamma[i]+gder3*Rational(1,6)*sdelta[i])
    frob = expand(frob)
    charpoly = naive_det(expand(z*eye(4) - frob))
    l = []
    for i in range(5):
        print(".")
        l.append(p_adic_valuation_at_pi(charpoly.coeff(z,i)))
    return l

# Run parameters.  p is read as a module-level global by the functions above.
#
# We promise exactness of the computation up to p-adic valuation val.
# Higher val means potentially more precision (even though only the first
# nonzero term counts), but longer computation.
p = 3
val = 15
o = 50
print("Prime",p)
print("Computation is exact up to p-adic valuation",val)
print("Frobenius computed up to order",o,"in the variable q")

outcome = cp2(o,val)
print()
print("Newton polygon at q=pi")
print(outcome)
