神题。
扑通一声跪下来,千古神犇vfk。
早就听说了这个造计算机题。正好,前几天上洛谷的省选课,这个题目作为提答作业布置了下来。于是我就开始了我的愉快作死之旅啦~
自己xjb乱搞,搞了56pts,发现不会做了XD
于是参考了chrt的题解和vfleaking的slide,各种卡,终于拿到了95pts…
未完待续
工具
按照vfk的题解的说法,题目中给出的是神经网络,也就是说,计算中没有“修改变量”的操作,只有“输入→输出”的映射。这是不是有点像函数式编程呢?在试着写了前 3 个点和第 5 个点之后,我发现运算里面行号的处理非常麻烦,如果把行号 hard-code到代码里面,很难看,而且很难调试。不如用函数式的思想,把“节点”稍微包装一下,这样就比较好调试了。还有个问题,题目里面的 90 位小数,怎么实现?难道手写高精度?既然是提交答案题,就不一定要用 C++。python 自带高精度浮点数(decimal
模块),而且语法很方便,就用 python 啦。NOI Linux自带python。
如下包装了几个基本命令:
#!/usr/bin/env python3
from decimal import Decimal
import decimal
decimal.getcontext().prec = 90
line = 1
def PutLine(str):
print(str)
global line
line = line + 1
return line - 1
class Node:
lineno = 0
def __init__(self, lineno):
self.lineno = lineno
def out(self): # 输出
Node(PutLine('O {}'.format(self.lineno)))
def shl(self, d): # 左移d位
return Node(PutLine('< {} {}'.format(self.lineno, d)))
def shr(self, d): # 右移d位
return Node(PutLine('> {} {}'.format(self.lineno, d)))
def add(self, y): # 加上y节点
return Node(PutLine('+ {} {}'.format(self.lineno, y.lineno)))
def opposite(self): # 取相反数
return Node(PutLine('- {}'.format(self.lineno)))
def sigmoid(self): # sigmoid函数
return Node(PutLine('S {}'.format(self.lineno)))
def offset(self, c): # 偏移常数c
return Node(PutLine('C {} {:.90f}'.format(self.lineno, c)))
def readin(): # 输入
return Node(PutLine('I'))
这样调用就很方便了,如第一个点,这么写即可:
readin().add(readin()).shl(1).opposite().out()
从上面一行可以直接地看出用了 $6$ 个基本操作。
以下混用 $="$ 和 $
\approx”$.
测试点1-2
基本操作。
测试点3
实现函数:
\[\begin{equation}\text{cmp}(x) = \begin{cases}-1 & a \lt 0 \\ 0 & a = 0 \\ 1 & a \gt 0\end{cases}\end{equation}\]利用题目中 $\text{Sigmoid}$ 函数的性质:在无穷远处趋近于 $0/1$ 。
容易发现,$s(a«100) = \begin{cases}0 & a<0 \ 1/2 & a = 0 \ 1 & a > 0 \end{cases}$ .
再平移变换即可满足题意。
代码:
def cmp(self):
return self.shl(100).sigmoid().offset(-0.5).shl(1)
测试点4
这个点就很有意思了。
如果直接用测试点3+乘法,只能得到6分。
满分解法是这样的:
我们考虑构造 $\lvert x \rvert$ : $\lvert x \rvert = x - \min\{2x, 0\}$.
$\min\{2x, 0\}$ 怎么实现?
\[min\\{2x, 0\\} = \begin{cases}2x & \text{if }x<0 \\ 0 & \text{otherwise} \end{cases}\]先构造上面式子的第一行。与正负有关,我们想到了构造第3个点的过程。 我们能不能引入某个量,让它在 $x>0$ 的时候能够去掉 $x$ 的贡献呢?怎么去掉贡献?发现 $s(x)$ 在 $x$ 趋近无穷大的时候趋近一个常数,这就是一种信息的丢失。利用这一点构造式子。如果我们要利用这一点,我们就得有方法把值从 $s(x)$ 还原到 $x$。考虑导数的定义:
\[\newcommand{d}{\text{d}} \frac{\d y}{\d x} = \lim_{x \rightarrow x_0} \frac{y-y_0}{x-x_0}\]于是我们可以利用导数来在某个点附近“线性拟合”某函数。对于 $s(x)$ ,不妨在 $x=0$ 处求导,于是有:
\[\newcommand{e}{\text{e}}\frac{\d y}{\d x} = \frac{\e^{-x}}{(1+\e^{-x})^2} \Rightarrow f'(0) = \frac{1}{4}\]这样当 $x$ 接近 $0$ 的时候, $s(x)$ 取值就可以用 $y = \frac{1}{4}x + \frac{1}{2}$ 估计了。
那么开始构造吧:
\[c = s(x<<150)<<152\\ r = s((x>>150)+c) \\ p = ((r-0.5)<<153) - c \\ x = x-p\]这样就可以了。代入任何一个负数/正数发现都满足题意。没有处理 $x=0$ ,因为可以给输入统一偏移一个小常数来避免。卡一卡代码长度,可以令 $p = -p$, 即:
\[p = (((-r)+0.5)<<153)+c\\ x = x+p\]把 $3$ 个减号变成 $1$ 个后就可以拿到满分了。
代码:
def abs(self):
x = self.offset(Decimal('1e-40'))
c = x.shl(150).sigmoid().shl(152)
r = x.shr(150).add(c).sigmoid()
p = r.opposite().offset(Decimal('0.5')).shl(153).add(c)
return x.add(p)
测试点5
bin-to-decimal转换。
直接搞就行了,需要大力卡常,连临时变量都不能用。
代码:
def bcd(t):
for i in range(31):
t[i] = t[i].shl(31-i)
for i in range(1, 32):
t[i] = t[i].add(t[i-1])
return t[31]
bcd([readin() for i in range(32)]).out()
测试点6
decimal-to-bin转换。
同样是直接搞+大力卡常。
卡了3h+常,仍然只有 $8$ 分。原因不明。
代码:
def dcb(a): # 8 pts
ret = list()
a = a.offset(Decimal('1e-40'))
one = a.shl(300).sigmoid()
for i in range(31, 0, -1):
b = a.add(one.shl(i).opposite()).shl(300).sigmoid()
ret.append(b)
a = a.add(b.shl(i).opposite())
ret.append(a)
return ret
BinList = dcb(readin())
for i in BinList:
i.out()
测试点7
同上 $8$ 分。原因同样不明。
按位处理即可。注意,这样实现用的节点数更少:$a \text{ xor } b = a+b-2s((a+b-1.5)«300)$
按照vfk课件里面的说法,第 $6$ 个点写挫了第 $7$ 个点也会挂。好像说中了orz
代码:
def getxor(a, b): # 8 pts
a = a.offset(Decimal('1e-40'))
b = b.offset(Decimal('1e-40'))
one = a.shl(300).sigmoid()
ans = a.shr(300)
for i in range(31, 0, -1):
t = a.add(one.shl(i).opposite()).shl(300).sigmoid()
r = b.add(one.shl(i).opposite()).shl(300).sigmoid()
s = t.add(r)
ans = ans.add(s.add(s.offset(-1.5).shl(300).sigmoid().opposite().shl(1)).shl(i))
a = a.add(t.shl(i).opposite())
b = b.add(r.shl(i).opposite())
s = a.add(b)
ans = ans.add(s.add(s.offset(-1.5).shl(300).sigmoid().opposite().shl(1)))
return ans
getxor(readin(), readin()).out()
测试点8
有意思*2
除以一个一般常数,我没能想出来不用乘法节点的方法。看了题解才知道是再次利用导数来实现线性变换。
我的理解就是,题目虽然给出的是非线性变换,但是在某个点处取极限后,就可以在那个点的邻域看作线性变换了。
找到一个点 $x_0$ ,使得 $s’(x_0) = 0.1$ ,那么在那个点附近函数值可以看成满足直线 $y - s(x_0) = 0.1(x-x_0)$。
在这个点附近做 $x \rightarrow 0.1x$ 的线性变换即可。
那么问题来了:怎么找出一个导数值为 $0.1$ 的点呢?double
精度是不够的,我们需要高精度的浮点数……
等等,用的是python啊,不是自带decimal
么?
于是方法就清晰了:首先求一个满足 $90$ 位精度的 $\e$ (用泰勒展开,$e = 1 + \frac{1}{1!} + \frac{1}{2!} + \frac{1}{3!} + \frac{1}{4!} + \cdots$),然后二分查出一个满足条件的 $x_0$,再求出 $s(x_0)$ 。最后把 $x$ 缩到很小,做上述的线性变换,再放大回原来的倍数。
代码:
def gete(n):
cur = 1
e = Decimal('0')
for i in range(1, n): # e = 1 + 1/1! + 1/2! + 1/3! + 1/4! + ...
e += Decimal('1') / cur
cur *= i
return e
def f(x):
return (1 / (1 + e**(-x)))
def g(x):
return (1 / ((1 + e**(-x))**2)) * e**(-x) - Decimal('0.1')
def g10(): # x satisfying f'(x) = 1/10
l = Decimal('2.0')
r = Decimal('2.1')
for i in range(1000):
mid = (l + r) / 2
if g(mid) < 0:
r = mid
else:
l = mid
return l
def div10(a):
t = g10()
dx = a.shr(100)
x2 = dx.offset(t)
dy = x2.sigmoid().offset(-f(t))
return dy.shl(100)
div10(readin()).out()
测试点9
只适用排序网络,因为没有if语句。
本蒟蒻不会双调排序……不过这个题目冒泡排序足矣。毕竟 $n = 16$ 。
考虑比较器的实现。定义
\[\text{hlp}(x) = \begin{cases} x & \text{ if } x>=0\\0 & \text{otherwise}\end{cases}\]显然 $\text{hlp}(x) = (x+ \lvert x \rvert)»1$.
于是这么构造:
\[x' = x-\text{hlp}(x-y) \\ y' = y + \text{hlp}(x-y)\]则 $x’ = \min\{x, y\}, y’ = \max\{x, y\}$.
再按照冒泡排序摆一堆比较器就行了。
代码:
在Node
类中:
def hlp(self):
return self.add(self.abs()).shr(1)
def comparator(x, y):
d = x.add(y.opposite()).hlp()
return (x.add(d.opposite()), y.add(d))
主程序中:
def bubblesort(nd):
for s in range(15, 0, -1):
for i in range(s):
nd[i], nd[i+1] = nd[i].comparator(nd[i+1]) # 丢一堆比较器
return nd
list(map(Node.out, bubblesort([readin() for i in range(16)]))) # list: 迭代map来解析整个结果列表
测试点10
求 $a \cdot b \text{ mod } m$.
快速乘法。
这个点我是这样做的:我们定义 $\text{hlp}2(x)$:
\[\text{hlp2}(x) = \begin{cases}c & \text{ if } x \leq 0\\ 0 & \text{otherwise}\end{cases}\]这个怎么实现呢?类比第四个测试点,利用上 $\text{cmp}(x)$ .
\[sgn = s[(x-eps)<<150]<<151 \\ t = s[(c>>150)+sgn]\\ ret = [(t-0.5)<<152] - sgn\]$ret$ 即为 $\text{hlp2}(x)$ 的值。
然后再写快速乘法:首先用倍增的方法去掉 $a$ 中的所有 $m$, 使得 $a \in [0, m)$ . 然后每次迭代需要把 $[0, 2m)$ 内的一个数字对 $m$ 取模,这个用上面的函数可以实现。直接这么写的话,共 $2196$ 行,可以获得 $9$ 分。
代码:
在Node
类中:
def hlp2(self, c):
sgn = self.offset(Decimal('-1e-40')).shl(150).sigmoid().shl(151)
t = c.shr(150).add(sgn).sigmoid()
ret = t.offset(Decimal('-0.5')).shl(152).add(sgn.opposite())
return ret
主程序:
def fastmul(a, x, m):
ret = a.shr(150)
one = a.offset(300).sigmoid()
minusone = one.opposite()
minusm = m.opposite()
binx = dcb(x)
for i in range(31, -1, -1):
t = m.shl(i)
d = a.add(t.opposite()).offset(Decimal('1e-40'))
a = a.add(d.opposite().hlp2(t).opposite())
for i in range(31, -1, -1):
ret = ret.add(binx[i].opposite().offset(Decimal('1')).hlp2(a))
ret = ret.add(ret.add(minusm).opposite().hlp2(minusm))
a = a.shl(1)
a = a.add(a.add(minusm).opposite().hlp2(minusm))
return ret
fastmul(readin(), readin(), readin()).out()
代码
放个总的代码:
#!/usr/bin/env python3
from decimal import Decimal
import decimal
decimal.getcontext().prec = 90
line = 1
def PutLine(str):
print(str)
global line
line = line + 1
return line - 1
class Node:
lineno = 0
def __init__(self, lineno):
self.lineno = lineno
def out(self):
Node(PutLine('O {}'.format(self.lineno)))
def shl(self, d):
return Node(PutLine('< {} {}'.format(self.lineno, d)))
def shr(self, d):
return Node(PutLine('> {} {}'.format(self.lineno, d)))
def add(self, y):
return Node(PutLine('+ {} {}'.format(self.lineno, y.lineno)))
def opposite(self):
return Node(PutLine('- {}'.format(self.lineno)))
def sigmoid(self):
return Node(PutLine('S {}'.format(self.lineno)))
def offset(self, c):
return Node(PutLine('C {} {:.90f}'.format(self.lineno, c)))
def cmp(self):
return self.shl(100).sigmoid().offset(-0.5).shl(1)
def abs(self):
x = self.offset(Decimal('1e-40'))
c = x.shl(150).sigmoid().shl(152)
r = x.shr(150).add(c).sigmoid()
p = r.opposite().offset(Decimal('0.5')).shl(153).add(c)
return x.add(p)
def hlp(self):
'''hlp(x) = { x if x>=0, 0 otherwise'''
return self.add(self.abs()).shr(1)
def hlp2(self, c):
'''hlp2(x) = { c if x<=0, 0 otherwise'''
sgn = self.offset(Decimal('-1e-40')).shl(150).sigmoid().shl(151)
t = c.shr(150).add(sgn).sigmoid()
ret = t.offset(Decimal('-0.5')).shl(152).add(sgn.opposite())
return ret
def comparator(x, y):
d = x.add(y.opposite()).hlp()
return (x.add(d.opposite()), y.add(d))
def readin():
return Node(PutLine('I'))
# helper functions
def gete(n):
cur = 1
e = Decimal('0')
for i in range(1, n): # e = 1 + 1/1! + 1/2! + 1/3! + 1/4! + ...
e += Decimal('1') / cur
cur *= i
return e
def f(x):
return (1 / (1 + e**(-x)))
def g(x):
return (1 / ((1 + e**(-x))**2)) * e**(-x) - Decimal('0.1')
def g10(): # x satisfying f'(x) = 1/10
l = Decimal('2.0')
r = Decimal('2.1')
for i in range(1000):
mid = (l + r) / 2
if g(mid) < 0:
r = mid
else:
l = mid
return l
# helper end
def bubblesort(nd):
for s in range(15, 0, -1):
for i in range(s):
nd[i], nd[i+1] = nd[i].comparator(nd[i+1])
return nd
def dcb(a): # 8 pts
ret = list()
a = a.offset(Decimal('1e-40'))
one = a.shl(300).sigmoid()
for i in range(31, 0, -1):
b = a.add(one.shl(i).opposite()).shl(300).sigmoid()
ret.append(b)
a = a.add(b.shl(i).opposite())
ret.append(a)
return ret
def bcd(t):
for i in range(31):
t[i] = t[i].shl(31-i)
for i in range(1, 32):
t[i] = t[i].add(t[i-1])
return t[31]
def getxor(a, b): # 8 pts
a = a.offset(Decimal('1e-40'))
b = b.offset(Decimal('1e-40'))
one = a.shl(300).sigmoid()
ans = a.shr(300)
for i in range(31, 0, -1):
t = a.add(one.shl(i).opposite()).shl(300).sigmoid()
r = b.add(one.shl(i).opposite()).shl(300).sigmoid()
s = t.add(r)
ans = ans.add(s.add(s.offset(-1.5).shl(300).sigmoid().opposite().shl(1)).shl(i))
a = a.add(t.shl(i).opposite())
b = b.add(r.shl(i).opposite())
s = a.add(b)
ans = ans.add(s.add(s.offset(-1.5).shl(300).sigmoid().opposite().shl(1)))
return ans
def div10(a):
t = g10()
dx = a.shr(100)
x2 = dx.offset(t)
dy = x2.sigmoid().offset(-f(t))
return dy.shl(100)
def fastmul(a, x, m):
ret = a.shr(150)
one = a.offset(300).sigmoid()
minusone = one.opposite()
minusm = m.opposite()
binx = dcb(x)
for i in range(31, -1, -1):
t = m.shl(i)
d = a.add(t.opposite()).offset(Decimal('1e-40'))
a = a.add(d.opposite().hlp2(t).opposite())
for i in range(31, -1, -1):
ret = ret.add(binx[i].opposite().offset(Decimal('1')).hlp2(a))
ret = ret.add(ret.add(minusm).opposite().hlp2(minusm))
a = a.shl(1)
a = a.add(a.add(minusm).opposite().hlp2(minusm))
return ret
def main():
global e
e = gete(100) # 标准库中e的精度不够,用泰勒展开算一个达到1e-90精度的
# your solution goes here
pass
if __name__ == '__main__':
main()