This commit is contained in:
Andy Chu 2014-11-11 17:48:52 -08:00
Родитель 860cf98686
Коммит 21ff2461f7
2 изменённых файлов: 93 добавлений и 28 удалений

Просмотреть файл

@ -25,6 +25,44 @@ import sys
import rappor
def SumBits(params, stdin, stdout):
csv_in = csv.reader(stdin)
csv_out = csv.writer(stdout)
num_cohorts = params.num_cohorts
num_bloombits = params.num_bloombits
sums = [[0] * num_bloombits for _ in xrange(num_cohorts)]
num_reports = [0] * num_cohorts
for i, row in enumerate(csv_in):
try:
(user_id, cohort, irr) = row
except ValueError:
raise RuntimeError('Error parsing row %r' % row)
if i == 0:
continue # skip header
cohort = int(cohort)
num_reports[cohort] += 1
if not len(irr) == params.num_bloombits:
raise RuntimeError("Expected %d bits, got %r" % (params.num_bloombits, len(irr)))
for i, c in enumerate(irr):
bit_num = num_bloombits - i - 1 # e.g. char 0 = bit 15, char 15 = bit 0
if c == '1':
sums[cohort][bit_num] += 1
else:
if c != '0':
raise RuntimeError('Invalid IRR -- digits should be 0 or 1')
for cohort in xrange(num_cohorts):
# First column is the total number of reports in the cohort.
row = [num_reports[cohort]] + sums[cohort]
csv_out.writerow(row)
def main(argv):
try:
filename = argv[1]
@ -36,34 +74,7 @@ def main(argv):
except rappor.Error as e:
raise RuntimeError(e)
num_cohorts = params.num_cohorts
num_bloombits = params.num_bloombits
sums = [[0] * num_bloombits for _ in xrange(num_cohorts)]
num_reports = [0] * num_cohorts
csv_in = csv.reader(sys.stdin)
for i, (user_id, cohort, irr) in enumerate(csv_in):
if i == 0:
continue # skip header
cohort = int(cohort)
num_reports[cohort] += 1
assert len(irr) == 16, len(irr)
for i, c in enumerate(irr):
bit_num = num_bloombits - i - 1 # e.g. char 0 = bit 15, char 15 = bit 0
if c == '1':
sums[cohort][bit_num] += 1
else:
if c != '0':
raise RuntimeError('Invalid IRR -- digits should be 0 or 1')
csv_out = csv.writer(sys.stdout)
for cohort in xrange(num_cohorts):
# First column is the total number of reports in the cohort.
row = [num_reports[cohort]] + sums[cohort]
csv_out.writerow(row)
SumBits(params, sys.stdin, sys.stdout)
if __name__ == '__main__':

54
analysis/tools/sum_bits_test.py Executable file
Просмотреть файл

@ -0,0 +1,54 @@
#!/usr/bin/python -S
"""
sum_bits_test.py: Tests for sum_bits.py
"""
import cStringIO
import unittest
import rappor
import sum_bits # module under test
CSV_IN = """\
user_id,cohort,rappor
5,1,0000111100001111
5,1,0000000000111100
"""
# NOTE: bit order is reversed.
EXPECTED_OUT_PREFIX = """\
0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0\r
2,1,1,2,2,1,1,0,0,1,1,1,1,0,0,0,0\r
"""
TOO_MANY_COLUMNS = """\
user_id,cohort,rappor
5,1,0000111100001111,extra
"""
class SumBitsTest(unittest.TestCase):
def setUp(self):
self.params = rappor.Params()
self.params.num_bloombits = 16
self.params.num_cohorts = 2
def testSum(self):
stdin = cStringIO.StringIO(CSV_IN)
stdout = cStringIO.StringIO()
sum_bits.SumBits(self.params, stdin, stdout)
self.assertMultiLineEqual(EXPECTED_OUT_PREFIX, stdout.getvalue())
def testErrors(self):
stdin = cStringIO.StringIO(TOO_MANY_COLUMNS)
stdout = cStringIO.StringIO()
self.assertRaises(
RuntimeError, sum_bits.SumBits, self.params, stdin, stdout)
if __name__ == '__main__':
unittest.main()