aboutsummaryrefslogtreecommitdiff
path: root/tests/gen_counts.R
blob: 769677c4a12299a4c7f354a4b67273d8f93f0eba (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
#!/usr/bin/env Rscript
#
# Copyright 2015 Google Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

source('analysis/R/read_input.R')

RandomPartition <- function(total, weights) {
  # Outputs a random partition according to a specified distribution
  # Args:
  #   total - number of samples
  #   weights - non-negative weights that are proportional to the probability
  #             density function of the target distribution; at least one
  #             weight must be positive
  # Returns:
  #   a histogram sampled according to the pdf: one count per weight, carrying
  #   the names of weights. For a non-negative integer total the counts sum
  #   to total.
  # Example (the output is random):
  #   > RandomPartition(100, c(3, 2, 1, 0, 1))
  #   [1] 47 24 15  0 14
  if (any(weights < 0)) {
    stop("Probabilities cannot be negative")
  }

  if (sum(weights) == 0) {
    stop("Probabilities cannot sum up to 0")
  }

  bins <- length(weights)
  result <- rep(0, bins)

  # idiomatic way:
  #   rnd_list <- sample(strs, total, replace = TRUE, weights)
  #   apply(as.array(strs), 1, function(x) length(rnd_list[rnd_list == x]))
  #
  # The following is much faster for larger totals. We can replace a loop with
  # (tail) recursion, but R chokes with the recursion depth > 850.

  # Last bin that can receive any balls. Everything still unassigned when we
  # get there must land in it, so its conditional probability is exactly 1.
  last_positive <- max(which(weights > 0))

  w <- sum(weights)

  for (i in seq_len(bins)) {
    if (total <= 0) {
      break  # nothing left to distribute; the remaining bins stay at 0
    }

    # invariant: w = sum(weights[i:bins]) (up to floating-point error)
    # rather than computing sum every time leading to quadratic time, keep
    # updating it
    if (i == last_positive) {
      # The running difference w may be off from weights[i] by rounding error,
      # which would make the ratio slightly less than 1 and could drop balls.
      p <- 1
    } else {
      # The probability p is clamped to [0, 1] to avoid under/overflow errors.
      p <- min(max(weights[i] / w, 0), 1)
    }

    # draw the number of balls falling into the current bin
    rnd_draw <- rbinom(n = 1, size = total, prob = p)
    result[i] <- rnd_draw  # push rnd_draw balls from total to result[i]
    total <- total - rnd_draw
    w <- w - weights[i]
  }

  names(result) <- names(weights)

  result
}

GenerateCounts <- function(params, true_map, partition, reports_per_client) {
  # Fast simulation of the marginal table for RAPPOR reports
  # Args:
  #   params - parameters of the RAPPOR reporting process; the fields read
  #            here are m (cohorts), k (bits per cohort), f, p and q
  #   true_map - hashed true inputs; must have m * k rows and one column per
  #              true value
  #   partition - allocation of clients between true values
  #   reports_per_client - number of reports (IRRs) per client
  # Returns:
  #   a matrix with m rows and k + 1 columns: the first column is the number
  #   of reports in each cohort, the remaining k columns are the number of
  #   reports in which each bit was set after IRR
  if (nrow(true_map) != (params$m * params$k)) {
    # Build the message with paste(): cat() would print it to stdout and hand
    # NULL to stop(), producing an error without any message.
    stop(paste("Map does not match the params file!",
               "mk =", params$m * params$k,
               "nrow(map):", nrow(true_map),
               sep = " "))
  }

  # For each reporting type computes its allocation to cohorts.
  # Output is an m x strs matrix. The explicit nrow keeps that shape even
  # when there is a single cohort (vapply would return a plain vector).
  cohorts <- matrix(
    vapply(partition,
           function(count) RandomPartition(count, rep(1, params$m)),
           numeric(params$m)),
    nrow = params$m)

  # Expands to (m x k) x strs matrix, where each element (corresponding to the
  # bit in the aggregate Bloom filter) is repeated k times.
  expanded <- cohorts[rep(seq_len(params$m), each = params$k), , drop = FALSE]

  # Number of clients in each cohort.
  clients_per_cohort <- rowSums(cohorts)

  # For each bit, the number of clients reporting this bit:
  clients_per_bit <- rep(clients_per_cohort, each = params$k)

  # Computes the true number of bits set to one BEFORE PRR.
  true_ones <- apply(expanded * true_map, 1, sum)

  # PRR: a bit that is truly one stays one with probability 1 - f / 2, a bit
  # that is truly zero becomes one with probability f / 2.
  zeros <- clients_per_bit - true_ones  # clients where the bit is 0
  ones_in_prr <-
    rbinom(n = length(true_ones), size = true_ones, prob = 1 - params$f / 2) +
    rbinom(n = length(zeros), size = zeros, prob = params$f / 2)

  # Number of IRRs where each bit is reported (either as 0 or as 1)
  reports_per_bit <- clients_per_bit * reports_per_client

  ones_before_irr <- ones_in_prr * reports_per_client

  # IRR: a one is reported as one with probability q, a zero with
  # probability p.
  zeros_before_irr <- reports_per_bit - ones_before_irr
  ones_after_irr <-
    rbinom(n = length(ones_before_irr), size = ones_before_irr,
           prob = params$q) +
    rbinom(n = length(zeros_before_irr), size = zeros_before_irr,
           prob = params$p)

  # ones_after_irr is laid out cohort by cohort, hence byrow = TRUE.
  counts <- cbind(clients_per_cohort * reports_per_client,
                  matrix(ones_after_irr, nrow = params$m, ncol = params$k,
                         byrow = TRUE),
                  deparse.level = 0)

  if (any(is.na(counts))) {
    stop("Failed to generate bit counts. Likely due to integer overflow.")
  }

  counts
}

ComputePdf <- function(distr, range) {
  # Outputs a discrete probability density function for a named distribution
  # Args:
  #   distr - distribution name: 'exp', 'gauss', 'unif', 'zipf1' or 'zipf1.5'
  #           (these are the five distributions in gen_sim_input.py)
  #   range - number of values the distribution is defined over
  # Returns:
  #   a vector of probabilities, normalized so that it sums to 1

  unnormalized <- switch(distr,
    "exp" = dexp(1:range, rate = 5 / range),
    "gauss" = {
      # Evaluate the normal density on unit steps from -range / 2 + 1 up to
      # range / 2.
      upper <- range / 2
      dnorm((1 - upper):upper, sd = range / 6)
    },
    # e.g. for N = 4, weights are [0.25, 0.25, 0.25, 0.25]
    "unif" = dunif(1:range, max = range),
    # Since the distribution is defined over a finite set, we allow the
    # parameter of the Zipf distribution to be 1.
    "zipf1" = 1 / (1:range),
    "zipf1.5" = 1 / (1:range)^1.5,
    stop(sprintf("Invalid distribution '%s'", distr))
  )

  # normalize
  unnormalized / sum(unnormalized)
}

# Usage:
#
# $ ./gen_counts.R exp 10000 1 foo_params.csv foo_true_map.csv foo
#
# Inputs:
#   distribution name
#   number of clients
#   reports per client
#   parameters file
#   map file
#   prefix for output files
# Outputs:
#   foo_counts.csv
#   foo_hist.csv
#
# Warning: the number of reports in any cohort must be less than
#          .Machine$integer.max

main <- function(argv) {
  # Simulates RAPPOR counts for a given distribution and writes the results
  # Args:
  #   argv - command-line arguments, in order: distribution name, number of
  #          clients, reports per client, parameters file, map file, prefix
  #          for output files
  # Side effects:
  #   writes <prefix>_counts.csv and <prefix>_hist.csv and prints their paths
  if (length(argv) < 6) {
    stop("Usage: gen_counts.R <distribution> <number of clients> ",
         "<reports per client> <parameters file> <map file> <output prefix>")
  }

  distr <- argv[[1]]
  num_clients <- as.integer(argv[[2]])
  reports_per_client <- as.integer(argv[[3]])
  params_file <- argv[[4]]
  true_map_file <- argv[[5]]
  out_prefix <- argv[[6]]

  # as.integer() yields NA for input it cannot parse; fail here with a clear
  # message instead of deep inside the simulation.
  if (is.na(num_clients) || is.na(reports_per_client)) {
    stop("Number of clients and reports per client must be integers")
  }

  params <- ReadParameterFile(params_file)

  true_map <- ReadMapFile(true_map_file, params)

  num_unique_values <- length(true_map$strs)

  pdf <- ComputePdf(distr, num_unique_values)

  # Computes the number of clients reporting each string
  # according to the pre-specified distribution.
  partition <- RandomPartition(num_clients, pdf)

  # Histogram
  true_hist <- data.frame(string = true_map$strs, count = partition)

  counts <- GenerateCounts(params, true_map$map, partition, reports_per_client)

  # Now create a CSV file

  # Opposite of ReadCountsFile in read_input.R
  # http://stackoverflow.com/questions/6750546/export-csv-without-col-names
  counts_path <- paste0(out_prefix, '_counts.csv')
  write.table(counts, file = counts_path,
              row.names = FALSE, col.names = FALSE, sep = ',')
  cat(sprintf('Wrote %s\n', counts_path))

  # TODO: Don't write strings that appear 0 times?
  hist_path <- paste0(out_prefix, '_hist.csv')
  write.csv(true_hist, file = hist_path, row.names = FALSE)
  cat(sprintf('Wrote %s\n', hist_path))
}

# Script entry point: run main() with the command-line arguments only when
# this code is evaluated with no function call active (sys.frames() is empty).
# NOTE(review): presumably this keeps main() from running when the file is
# loaded via source() from other code — confirm against the callers.
if (length(sys.frames()) == 0) {
  main(commandArgs(TRUE))
}