import constants from 'app/constants'
import { parseCurrencyParameter } from 'common/util/currencies'
/**
 * Numerically stable log(sum(exp(x_i))).
 *
 * Every term is shifted by the largest element before exponentiating so that
 * Math.exp cannot overflow; the shift is added back outside the logarithm.
 *
 * @param {...number} elements - Values to combine.
 * @returns {number} log(Σ exp(element)); -Infinity when called with no arguments.
 */
function logSumExp(...elements) {
  const shift = Math.max(...elements)
  let total = 0
  for (const value of elements) {
    total += Math.exp(value - shift)
  }
  return shift + Math.log(total)
}
/**
 * Find a root of `f` inside the bracket [xa, xb] using Brent's method:
 * bisection combined with linear (secant) or inverse-quadratic interpolation
 * whenever the interpolated step looks safe.
 *
 * NOTE(review): the variable names and control flow closely mirror SciPy's C
 * `brentq` implementation, so this is presumably a port of it. Confirm this
 * before relying on SciPy's documented convergence guarantees.
 *
 * `f` is evaluated twice up front (at xa and xb) and once per iteration.
 *
 * @param {(x: number) => number} f - Function whose root is sought.
 * @param {number} [xa=0] - One end of the bracketing interval.
 * @param {number} [xb=10000] - Other end of the bracketing interval.
 * @param {number} [xtol=1e-8] - Absolute tolerance on the root.
 * @param {number} [rtol=1e-8] - Relative tolerance, scaled by |x|.
 * @param {number} [iter=100] - Maximum number of iterations.
 * @returns {number} An x with f(x) === 0, or the best estimate once the half-width
 *   of the bracket drops below (xtol + rtol * |x|) / 2.
 * @throws {Error} If f(xa) and f(xb) have the same sign, or if `iter` iterations
 *   pass without convergence.
 */
function brentq(f, xa = 0, xb = 10000, xtol = 1e-8, rtol = 1e-8, iter = 100) {
  // xpre/fpre: previous iterate; xcur/fcur: current (best) estimate.
  let xpre = xa,
    xcur = xb
  // xblk/fblk: the "contrapoint", i.e. the bracket end whose f has the opposite sign of fcur.
  let xblk = 0.0,
    fpre = f(xpre),
    fcur = f(xcur)
  // spre/scur: step taken before last / last step. sbis: bisection step.
  // delta: convergence tolerance. stry: candidate interpolation step.
  // dpre/dblk: divided differences used by inverse quadratic interpolation.
  let fblk = 0.0,
    spre = 0.0,
    scur = 0.0,
    sbis,
    delta,
    stry,
    dpre,
    dblk

  // Check if f(xa) or f(xb) is already zero
  if (fpre === 0) return xpre
  if (fcur === 0) return xcur

  // Ensure the root is bracketed, i.e., f(xa) and f(xb) have opposite signs
  if (Math.sign(fpre) === Math.sign(fcur)) {
    throw new Error('The function values at the interval endpoints must have opposite signs.')
  }

  // Begin iterations
  for (let i = 0; i < iter; i++) {
    // When the previous and current iterates straddle the root, the previous one
    // becomes the new contrapoint and the step history is reset to the bracket width.
    if (Math.sign(fpre) !== Math.sign(fcur)) {
      xblk = xpre
      fblk = fpre
      spre = scur = xcur - xpre
    }

    // Swap so that |f(xcur)| <= |f(xblk)|. xcur is then the best estimate so far,
    // and the old xcur becomes both the previous iterate and the contrapoint.
    if (Math.abs(fblk) < Math.abs(fcur)) {
      ;[xpre, xcur] = [xcur, xblk]
      ;[fpre, fcur] = [fcur, fblk]
      xblk = xpre
      fblk = fpre
    }

    // Tolerance mixes the absolute and relative parts; sbis is half the distance to the contrapoint.
    delta = (xtol + rtol * Math.abs(xcur)) / 2
    sbis = (xblk - xcur) / 2

    if (fcur === 0 || Math.abs(sbis) < delta) {
      return xcur // Converged to the root
    }

    // Only interpolate when the step before last was larger than the tolerance
    // and the last step reduced |f|; otherwise bisect.
    if (Math.abs(spre) > delta && Math.abs(fcur) < Math.abs(fpre)) {
      // Attempt interpolation or extrapolation
      if (xpre === xblk) {
        // Linear interpolation (only two distinct points are known)
        stry = (-fcur * (xcur - xpre)) / (fcur - fpre)
      } else {
        // Inverse quadratic interpolation
        dpre = (fpre - fcur) / (xpre - xcur)
        dblk = (fblk - fcur) / (xblk - xcur)
        stry = (-fcur * (fblk * dblk - fpre * dpre)) / (dblk * dpre * (fblk - fpre))
      }

      // Accept the interpolated step only if 2*|stry| < min(|spre|, 3*|sbis| - delta),
      // i.e. it stays inside the bracket and shrinks faster than the step before last.
      if (2 * Math.abs(stry) < Math.min(Math.abs(spre), 3 * Math.abs(sbis) - delta)) {
        spre = scur
        scur = stry
      } else {
        // Fall back to bisection
        spre = sbis
        scur = sbis
      }
    } else {
      // Fall back to bisection
      spre = sbis
      scur = sbis
    }

    xpre = xcur
    fpre = fcur
    // Never move by less than delta, so successive iterates stay distinguishable.
    if (Math.abs(scur) > delta) {
      xcur += scur
    } else {
      xcur += sbis > 0 ? delta : -delta
    }

    fcur = f(xcur)
  }

  throw new Error('Maximum iterations exceeded. Root not found.')
}

function lmsrPrice(question, outcome, shares, currencyId) {
  /**
   * Cost of buying `shares` of `outcome` under the LMSR:
   *   b * logSumExp((q + Δ) / b) - b * logSumExp(q / b)
   * where b is the question's liquidity parameter for `currencyId`, q holds the
   * current shares of every enabled outcome, and Δ adds `shares` to `outcome` only.
   *
   * Throws the string 'Error while calculating lmsr price: Not valid outcomes for question'
   * (a string, not an Error object) when there is no usable enabled outcome.
   */
  // TODO: refactor this in the future
  const { liquidity_param: liquidity } = question.scoring_rule_metadata.lmsr[currencyId]
  const enabledOutcomes = question.outcomes.filter(o => !o.disabled)

  const hasValidOutcomes = enabledOutcomes.length > 0 && typeof enabledOutcomes[0] === 'object'
  if (!hasValidOutcomes) {
    throw 'Error while calculating lmsr price: Not valid outcomes for question'
  }

  // Scaled share vectors before and after the purchase.
  const before = []
  const after = []
  enabledOutcomes.forEach(candidate => {
    const held = candidate.shares[currencyId]
    before.push(held / liquidity)
    after.push((candidate.id === outcome.id ? held + shares : held) / liquidity)
  })

  return liquidity * logSumExp(...after) - liquidity * logSumExp(...before)
}

function lmsrShares(question, outcome, balance, selectedCurrency) {
  // Number of shares of `outcome` that paying `balance` buys under the LMSR.
  //
  // The formula for two outcomes, when the transaction
  // is on outcome 1 is defined as:
  //
  // shares(c) = b * ln(e^(c/b) * (e^(q1/b) + e^(q2/b)) - e^(q2/b)) - q1
  //
  // where c is the money amount we want to pay
  // q1 is #shares on outcome 1
  // q2 is #shares on outcome 2
  // b is the liquidity param
  //
  // With more outcomes, (e^(q1/b) + e^(q2/b)) is the sum over every enabled outcome,
  // and the subtracted term is the sum over every enabled outcome except `outcome`.
  //
  // for more explanation, see this link on wolfram alpha
  // https://www.wolframalpha.com/input/?i=solve++c+%3D+b+*+ln(+e%5E((q1%2Bs)%2Fb)+%2B+e%5E(q2%2Fb)+)+-+b+*+ln(+e%5E(q1%2Fb)+%2B+e%5E(q2%2Fb)+)+for+s
  //
  // Overflow: every q_i/b is shifted by the largest one before exponentiating (the
  // shift is added back outside the log), and e^(c/b) is factored out of the log
  // when c/b > 0, so large share counts or balances do not turn Math.exp into Infinity.
  //
  // Throws the string 'Error while calculating lmsr shares: Not valid outcomes for question'
  // (a string, not an Error object) when there is no usable enabled outcome.
  const currencyId = parseCurrencyParameter(selectedCurrency)
  const liquidity = question.scoring_rule_metadata.lmsr[currencyId].liquidity_param
  // Only enabled outcomes take part in the market (same filter as lmsrPrice / lslmsrAmount).
  const outcomeList = question.outcomes.filter(o => !o.disabled)

  if (!(outcomeList.length > 0 && typeof outcomeList[0] === 'object')) {
    throw 'Error while calculating lmsr shares: Not valid outcomes for question'
  }

  const scaled = outcomeList.map(o => o.shares[currencyId] / liquidity)
  const maxScaled = Math.max(...scaled)

  // Both sums are divided by e^maxScaled.
  let sumAll = 0
  let sumOthers = 0
  outcomeList.forEach((o, i) => {
    const weight = Math.exp(scaled[i] - maxScaled)
    sumAll += weight
    if (o.id !== outcome.id) {
      sumOthers += weight
    }
  })

  // ln(e^x * sumAll - sumOthers), rewritten as x + ln(sumAll - sumOthers * e^-x) for x > 0.
  const x = balance / liquidity
  const logInner =
    x > 0 ? x + Math.log(sumAll - sumOthers * Math.exp(-x)) : Math.log(Math.exp(x) * sumAll - sumOthers)

  return liquidity * (maxScaled + logInner) - outcome.shares[currencyId]
}

const calculateAlpha = (tax, q_vector) => {
  /**
   * LS-LMSR alpha parameter: alpha = tax / (n * ln(n)), where n = q_vector.length.
   *
   * Parameters:
   * tax (number): tax parameter of the market.
   * q_vector (Array<number>): share quantities per outcome; only its length is used.
   *
   * Returns:
   * number: alpha. For n === 1 the denominator is 0, so the result is Infinity
   * (or NaN when tax is 0).
   */
  const outcomeCount = q_vector.length
  const denominator = outcomeCount * Math.log(outcomeCount)
  return tax / denominator
}
const bq = (tax, q_vector) => {
  /**
   * LS-LMSR liquidity b(q) = alpha * Σ q_i, with alpha from calculateAlpha.
   *
   * Parameters:
   * tax (number): tax parameter forwarded to calculateAlpha.
   * q_vector (Array<number>): share quantities per outcome.
   *
   * Returns:
   * number: b(q).
   */
  const alpha = calculateAlpha(tax, q_vector)
  let total = 0
  for (const quantity of q_vector) {
    total += quantity
  }
  return alpha * total
}

const amountFunction = (q_vector, outcomeIndex, shares, tax, position) => {
  /**
   * Cost of adding `shares` to the market, i.e.
   * costFromShares(tax, shifted) - costFromShares(tax, q_vector).
   *
   * position 'l' (long): `shares` is added to q_vector[outcomeIndex] only.
   * Any other position value (short): `shares` is added to every entry except outcomeIndex.
   *
   * q_vector itself is not modified; a copy is shifted instead.
   */
  const shifted = [...q_vector]
  if (position === 'l') {
    shifted[outcomeIndex] += shares
  } else {
    shifted.forEach((quantity, idx) => {
      if (idx !== outcomeIndex) {
        shifted[idx] = quantity + shares
      }
    })
  }
  return costFromShares(tax, shifted) - costFromShares(tax, q_vector)
}

const costFromShares = (tax, q_vector) => {
  /**
   * LS-LMSR cost function:
   *   C(q) = b(q) * log(Σ_i exp(q_i / b(q)))
   * where b(q) = alpha * Σ_i q_i is computed by bq(tax, q_vector).
   *
   * The log-sum-exp is evaluated with the usual max shift, so a small alpha
   * (large q_i / b) does not overflow Math.exp to Infinity.
   *
   * Parameters:
   * tax (number): tax parameter forwarded to bq.
   * q_vector (Array<number>): share quantities per outcome.
   *
   * Returns:
   * number: C(q). Degenerate inputs (empty vector, zero total shares) yield NaN,
   * as with the direct formula.
   */
  const b = bq(tax, q_vector)
  const scaled = q_vector.map(q => q / b)
  const maxScaled = Math.max(...scaled)

  if (!Number.isFinite(maxScaled)) {
    // Empty vector, b === 0, or non-finite ratios: no safe shift exists, so keep the direct formula.
    return b * Math.log(scaled.reduce((sum, x) => sum + Math.exp(x), 0))
  }

  const shiftedSum = scaled.reduce((sum, x) => sum + Math.exp(x - maxScaled), 0)
  return b * (maxScaled + Math.log(shiftedSum))
}

function sharesFromCost(q_vector, outcomeIndex, cost, tax, position = 'l') {
  // Find how many shares `cost` buys, by solving
  //   costFromShares(tax, q') - costFromShares(tax, q_vector) = cost
  // for the share count with brentq (using brentq's default search bracket).
  //
  // Parameters:
  // q_vector (Array<number>): share quantities per outcome. MUTATED in place to record the trade.
  // outcomeIndex (int): index of the outcome being bought.
  // cost (number): amount to spend.
  // tax (number): tax parameter forwarded to costFromShares.
  // position (string): 'l' (long, default) adds the shares to q_vector[outcomeIndex];
  //   's' (short) adds them to every other outcome.
  //
  // Returns:
  // number: the number of shares bought (0 when cost is 0).
  //
  // Throws:
  // Error on an invalid position, when brentq cannot bracket/converge, or when the
  // solver returns a non-finite value.
  if (position !== 'l' && position !== 's') {
    throw new Error("Invalid position. Must be 'l' (long) or 's' (short).")
  }

  // Market state after adding `shares` on the chosen side (q_vector itself is untouched here).
  const shiftedVector = shares => {
    if (position === 'l') {
      const updated = [...q_vector]
      updated[outcomeIndex] += shares
      return updated
    }
    return [...q_vector].map((q, i) => (i === outcomeIndex ? q : q + shares))
  }

  // The baseline cost does not depend on the trial share count; compute it once.
  const baseCost = costFromShares(tax, q_vector)
  const costDifference = shares => costFromShares(tax, shiftedVector(shares)) - baseCost - cost

  const shares = brentq(costDifference)

  // 0 is a legitimate root (cost === 0); only non-finite results are failures.
  if (!Number.isFinite(shares)) {
    throw new Error('Failed to find the number of shares for the given cost.')
  }

  // Record the trade exactly as the solver modelled it.
  if (position === 'l') {
    q_vector[outcomeIndex] += shares
  } else {
    for (let i = 0; i < q_vector.length; i++) {
      if (i !== outcomeIndex) {
        q_vector[i] += shares
      }
    }
  }

  return shares
}

const sellingShares = (q_vector, outcome, shares, tax, position = 'l', method = 'ALWAYS_MOVING_FORWARD') => {
  /**
   * Calculate the amount a trader receives for selling `shares` of a position.
   *
   * method 'ALWAYS_MOVING_FORWARD' (default): the market never removes shares.
   *   Selling a long position adds `shares` to every outcome except `outcome`;
   *   selling a short position adds `shares` to `outcome`. The trader receives
   *   `shares` minus the resulting increase in cost.
   * Any other method: the sold shares are removed. A long position subtracts
   *   `shares` from `outcome`; a short position subtracts `shares` from every other
   *   outcome. The trader receives the resulting decrease in cost.
   *
   * Parameters:
   * q_vector (Array<number>): current share quantities per outcome. Not mutated.
   * outcome (int): index of the outcome the position refers to.
   * shares (number): number of shares being sold.
   * tax (number): tax parameter forwarded to costFromShares.
   * position (string): 'l' (long, default) or 's' (short).
   * method (string): 'ALWAYS_MOVING_FORWARD' (default) or any other value for removal.
   *
   * Returns:
   * number: the amount received by the trader.
   *
   * Throws:
   * Error if `position` is not 'l' or 's'.
   */
  if (!['l', 's'].includes(position)) {
    throw new Error("Invalid position. Must be 'l' (long) or 's' (short).")
  }

  const isLong = position === 'l'
  const baseCost = costFromShares(tax, q_vector)

  if (method === 'ALWAYS_MOVING_FORWARD') {
    // Long: grow the complementary outcomes. Short: grow the outcome itself.
    const forwardQ = q_vector.map((qi, i) => ((isLong ? i !== outcome : i === outcome) ? qi + shares : qi))
    return -1 * (costFromShares(tax, forwardQ) - baseCost - shares)
  }

  // Long: shrink the outcome itself. Short: shrink the complementary outcomes.
  const reducedQ = q_vector.map((qi, i) => ((isLong ? i === outcome : i !== outcome) ? qi - shares : qi))
  return -1 * (costFromShares(tax, reducedQ) - baseCost)
}

function lslmsrAmount(question, outcome, sharesOffset, selectedCurrency, position) {
  // ! deprecated , please use costFromShares instead
  //
  // LS-LMSR cost of adding `sharesOffset` shares, computed over the question's
  // enabled outcomes:
  //   C(q') - C(q),  C(q) = b(q) * logSumExp(q / b(q)),  b(q) = alpha * Σ q_i,
  //   alpha = tax / (n * ln(n)).
  // A long position adds sharesOffset to `outcome`; a short position
  // (position === constants.position.SHORT) adds it to every other enabled outcome.
  console.warn('! lslmsrAmount is deprecated , please use costFromShares instead')

  const shareList = []
  const newShareList = []
  const activeOutcomes = question.outcomes.filter(o => !o.disabled)
  const currencyId = parseCurrencyParameter(selectedCurrency)

  const tax = question.scoring_rule_metadata.lslmsr[currencyId].tax
  const alpha = tax / (activeOutcomes.length * Math.log(activeOutcomes.length))
  const isShort = position === constants.position.SHORT

  let shareSum = 0
  // Number of outcomes that actually received sharesOffset.
  let offsetCount = 0
  for (const questionOutcome of activeOutcomes) {
    const share = questionOutcome.shares[currencyId]
    shareList.push(share)
    shareSum += share

    const isTarget = questionOutcome.id === outcome.id
    // short: every outcome except the selected one; long: only the selected one
    const receivesOffset = isShort ? !isTarget : isTarget
    if (receivesOffset) {
      newShareList.push(share + sharesOffset)
      offsetCount += 1
    } else {
      newShareList.push(share)
    }
  }

  // b(q) depends on the total number of shares, so use the total actually added:
  // sharesOffset for a long trade, (n - 1) * sharesOffset for a short one.
  const aggregatedShareAdded = sharesOffset * offsetCount

  const oldB = alpha * shareSum
  const newB = alpha * (shareSum + aggregatedShareAdded)

  const oldCost = oldB * logSumExp(...shareList.map(e => e / oldB))
  const newCost = newB * logSumExp(...newShareList.map(e => e / newB))

  return newCost - oldCost
}

function lslmsrShares(question, outcome, inputAmount, currencyId, minShares, maxShares, position) {
  // ! deprecated , please use sharesFromCost instead
  console.warn('! lslmsrShares is deprecated , please use sharesFromCost instead')

  /**
   * Find the share count whose lslmsrAmount equals `inputAmount`, using the secant
   * method seeded with `minShares` and `maxShares`.
   *
   * Stops as soon as two consecutive guesses produce the same amount (returning the
   * newer guess), or after 10000 iterations (returning the last guess whose amount
   * was evaluated). The result can be NaN if the iteration diverges.
   */
  let previous = minShares
  let current = maxShares
  // Amount at `previous`; carried over between iterations instead of being recomputed.
  let previousAmount = lslmsrAmount(question, outcome, previous, currencyId, position)

  for (let i = 0; i < 10000; i++) {
    const currentAmount = lslmsrAmount(question, outcome, current, currencyId, position)
    if (currentAmount === previousAmount) {
      return current
    }
    // Secant step towards lslmsrAmount(x) === inputAmount.
    const next = current - ((currentAmount - inputAmount) * (current - previous)) / (currentAmount - previousAmount)
    previous = current
    previousAmount = currentAmount
    current = next
  }
  return previous
}

// Public API of this module. lslmsrAmount / lslmsrShares are deprecated (they log a
// console warning when called); costFromShares / sharesFromCost replace them.
export {
  lmsrPrice,
  lmsrShares,
  lslmsrAmount,
  lslmsrShares,
  sharesFromCost,
  brentq,
  costFromShares,
  amountFunction,
  sellingShares,
}
