题意
给定一个序列$\{a_i\}$,求它的每个子区间的逆序对数和。包括长度为1的区间。 对于$60\%$的数据,$n \leq 1000$; 对于$100\%$的数据,$n \leq 1000000$。
思路
先看$60\%$的解法。
当时在考试的时候,我选择的是用树状数组暴力扫。先离散化。对于同一个左端点$i$,依次向右边扫右端点$j$,对于每个j都统计$[i,j-1]$有几个逆序对。换$i$的时候,清空树状数组,再继续扫。时间复杂度是$O(n^2 \log n)$,可以通过$60\%$的数据。
出题人给的解法是对于每个逆序对,统计它对答案的贡献。
如果我们找到了一个逆序对$(i,j)$,那么就有$i \cdot (n-j+1)$个区间包含了它。
于是答案就是$\large \sum^{n}{i=1} \sum^{n}{j=i} i \cdot (n-j+1)$。
复杂度为$O(n^2)$。
其实这里就已经体现了$100\%$解法的思想了;100分解法只是把上述过程用树状数组进行了优化。首先按照权值建立树状数组,把$a_i$作为下标,$i$作为值。对于一个位置$j$,和它有关的逆序对有$(i_1,j),(i_2,j), \dots , (i_k,j)$,于是我们就可以利用权值树状数组查出$i_1+i_2+\dots+i_k$,再与$(n-j+1)$相乘即为这些逆序对的所有贡献。答案是$O(n^4)$级别的(由暴力可知),要打高精度。高精乘单精即可,高精乘高精会TLE(当然你打FFT除外)。
总结
对于“所有子区间内xxx的和”一类问题,可以采用统计贡献的方法:对于每个所求元素分别统计它对于答案的贡献,再相加得到最终的答案。
代码
#include <bits/stdc++.h>
#define rint register int
using namespace std;
typedef unsigned long long ULL;
const int MAXN = 1e6;
struct BigInteger {
static const int MAXN = 20; //1e80
static const int MOD = 10000;
static const int WIDTH = 4;
int v[MAXN+10], sz;
void Clear() {
while(sz>1 && v[sz-1]==0) sz--;
}
BigInteger() {
sz=0; v[sz++]=0;
}
BigInteger operator=(const string& str) {
//1|2345|6789
int p,start,end,len;
len=str.length(); sz=0;
p=len/WIDTH + ((bool)(len%WIDTH));
for(rint i=0; i<p; i++) {
end=len-i*WIDTH;
start=max(0, len-(i+1)*WIDTH);
v[sz++]=atoi(str.substr(start,end-start).c_str());
}
Clear();
return *this;
}
BigInteger operator=(unsigned long long x) {
sz=0;
do {
v[sz++] = x%MOD;
x/=MOD;
} while(x!=0);
Clear();
return *this;
}
BigInteger(const string& str) {
sz=0; *this=str;
}
BigInteger(const unsigned long long x) {
sz=0; *this=x;
}
bool operator<(const BigInteger& b) const {
if(sz<b.sz) return true;
if(sz>b.sz) return false;
for(rint i=sz-1; i>=0; i--) {
if(v[i]<b.v[i]) return true;
if(v[i]>b.v[i]) return false;
if(v[i]==b.v[i]) continue;
}
return false;
}
bool operator>(const BigInteger& b) const {
return b<*this;
}
bool operator==(const BigInteger& b) const {
return !(b<*this) && !(*this<b);
}
bool operator<=(const BigInteger& b) const {
return *this<b || *this==b;
}
bool operator>=(const BigInteger& b) const {
return *this>b || *this==b;
}
BigInteger operator+(const BigInteger& b) const {
BigInteger c;
c.sz=0;
for(rint i=0,g=0,x=0; ; i++) {
x=g;
if(g==0 && i>=sz && i>=b.sz) break;
if(i<sz) x+=v[i];
if(i<b.sz) x+=b.v[i];
c.v[c.sz++]=x%MOD;
g=x/MOD;
}
c.Clear();
return c;
}
BigInteger operator+=(const BigInteger& b) {
return *this=*this+b;
}
BigInteger operator-(const BigInteger& b) const {
BigInteger c;
c.sz=0;
for(rint i=0,g=0,x=0; ; i++) {
x=g;
if(g==0 && i>=sz && i>=b.sz) break;
if(i<sz) x+=v[i];
if(i<b.sz) x-=b.v[i];
c.v[c.sz++]=(x%MOD+MOD)%MOD;
g=g/MOD-(x<0);
}
c.Clear();
return c;
}
BigInteger operator-=(const BigInteger& b) {
return *this=*this-b;
}
BigInteger operator*(int b) const {
ULL g=0,x=0;
BigInteger c; c.sz=0;
for(rint i=0; ; i++) {
x=g;
if(g==0 && i>=sz) break;
if(i<sz) x+=1ull*v[i]*b;
c.v[c.sz++]=x%MOD;
g=x/MOD;
}
c.Clear();
return c;
}
friend istream& operator>>(istream& in, BigInteger& b) {
string str;
if(!(in>>str)) return in;
b=str; return in;
}
friend ostream& operator<<(ostream& out, const BigInteger& b) {
out<<b.v[b.sz-1];
static char t[MOD+10];
for(rint i=b.sz-2; i>=0; i--) {
sprintf(t,"%04d",b.v[i]);
out<<t;
}
return out;
}
};
template<typename T>
struct BIT {
int N; T A[MAXN+10];
inline int lowbit(int x) { return x&(-x); }
void add(int p, T val) {
while(p<=N) {
A[p]+=val;
p+=lowbit(p);
}
}
T query(int p) {
T ret=0;
while(p>0) {
ret+=A[p];
p-=lowbit(p);
}
return ret;
}
};
int N, A[MAXN+10], Num[MAXN+10];
BIT<ULL> Bit;
inline int readint() {
rint f=1,r=0; register char c=getchar();
while(!isdigit(c)) { if(c=='-')f=-1; c=getchar(); }
while(isdigit(c)) { r=r*10+c-'0'; c=getchar(); }
return f*r;
}
void Init() {
scanf("%d", &N);
for(rint i=1; i<=N; i++) A[i]=readint();
//离散化
for(rint i=1; i<=N; i++) Num[++Num[0]]=A[i];
sort(Num+1, Num+Num[0]+1);
Num[0]=unique(Num+1, Num+Num[0]+1) - &Num[1];
for(rint i=1; i<=N; i++) A[i]=lower_bound(Num+1,Num+Num[0]+1,A[i])-&Num[0];
}
void Work() {
BigInteger Ans=0; Bit.N = Num[0];
for(rint i=1; i<=N; i++) {
Ans+=BigInteger(Bit.query(Num[0])-Bit.query(A[i])) * (N-i+1);
Bit.add(A[i],i);
}
cout<<Ans;
}
int main() {
Init(); Work();
return 0;
}