aboutsummaryrefslogtreecommitdiff
path: root/okio/src/jvmMain/kotlin/okio/CipherSource.kt
blob: 154371f3a9bc0211ffca4b167a2d35196ad0a118 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
/*
 * Copyright (C) 2020 Square, Inc. and others.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package okio

import java.io.IOException
import javax.crypto.Cipher

/**
 * A [Source] that pulls bytes from [source], passes them through [cipher], and serves the
 * cipher's output to the reader.
 *
 * This class never initializes [cipher]; it only calls `getOutputSize`, `update` and `doFinal`
 * on it, so the caller is expected to hand in a cipher that is ready to process data.
 *
 * Output is normally produced one segment at a time. When the cipher reports that even a single
 * block of input may produce more than [Segment.SIZE] bytes (ciphers that hold input back until
 * `doFinal` can do this), the remainder of [source] is read into memory and finalized in one
 * step instead of failing.
 *
 * @param source the upstream bytes to feed into the cipher; closed by [close].
 * @property cipher the cipher used to transform the bytes; its block size must be positive.
 * @throws IllegalArgumentException if `cipher.blockSize` is not positive.
 */
class CipherSource(
  private val source: BufferedSource,
  val cipher: Cipher
) : Source {
  private val blockSize = cipher.blockSize

  /** Cipher output that has been produced but not yet handed to a reader. */
  private val buffer = Buffer()

  /** True once `cipher.doFinal` has been invoked; no more input is fed to the cipher after that. */
  private var final = false
  private var closed = false

  init {
    // Require block cipher
    require(blockSize > 0) { "Block cipher required $cipher" }
  }

  /**
   * Moves up to [byteCount] bytes of cipher output into [sink].
   *
   * @return the result of reading from the internal output buffer, after it has been refilled
   *     from the cipher if it was empty and the cipher had not been finalized yet.
   * @throws IllegalArgumentException if [byteCount] is negative.
   * @throws IllegalStateException if this source has been closed.
   */
  @Throws(IOException::class)
  override fun read(sink: Buffer, byteCount: Long): Long {
    require(byteCount >= 0L) { "byteCount < 0: $byteCount" }
    check(!closed) { "closed" }
    if (byteCount == 0L) return 0L
    if (final) return buffer.read(sink, byteCount)

    refill()

    return buffer.read(sink, byteCount)
  }

  /**
   * Feeds the cipher until it has produced at least one byte of output or has been finalized.
   *
   * The `!final` condition matters because [update] may finalize the cipher itself; without it
   * an empty finalization result would lead to a second `doFinal` call.
   */
  private fun refill() {
    while (buffer.size == 0L && !final) {
      if (source.exhausted()) {
        final = true
        doFinal()
        break
      } else {
        update()
      }
    }
  }

  /**
   * Runs the readable bytes of the upstream head segment (or a block-aligned shorter prefix of
   * them) through `cipher.update` and appends whatever output is produced to [buffer].
   */
  private fun update() {
    // Only called after source.exhausted() returned false, so a head segment is expected here.
    val head = source.buffer.head!!
    var size = head.limit - head.pos

    // Shorten input until output is guaranteed to fit within a segment
    var outputSize = cipher.getOutputSize(size)
    while (outputSize > Segment.SIZE) {
      if (size <= blockSize) {
        // The cipher claims that not even one block fits into a segment, so segment-wise
        // processing cannot make progress. Consume the rest of the input in one go and finish.
        // Note that this loads all remaining upstream bytes into memory.
        final = true
        buffer.write(cipher.doFinal(source.readByteArray()))
        return
      }
      size -= blockSize
      outputSize = cipher.getOutputSize(size)
    }

    // A cipher that only buffers this input may report 0 bytes of output; still request a
    // non-empty capacity so the segment lookup is valid. An unused segment is recycled below.
    val s = buffer.writableSegment(outputSize.coerceAtLeast(1))

    val ciphered =
      cipher.update(head.data, head.pos, size, s.data, s.pos)

    source.skip(size.toLong())

    s.limit += ciphered
    buffer.size += ciphered

    // We allocated a tail segment, but didn't end up needing it. Recycle!
    if (s.pos == s.limit) {
      buffer.head = s.pop()
      SegmentPool.recycle(s)
    }
  }

  /** Finalizes the cipher and appends any remaining output (e.g. a padded last block) to [buffer]. */
  private fun doFinal() {
    val outputSize = cipher.getOutputSize(0)
    if (outputSize == 0) return

    if (outputSize > Segment.SIZE) {
      // More pending output than one segment can hold: let the cipher allocate the result array
      // and copy it into the buffer instead of writing into a single segment.
      buffer.write(cipher.doFinal())
      return
    }

    val s = buffer.writableSegment(outputSize)

    val ciphered = cipher.doFinal(s.data, s.pos)

    s.limit += ciphered
    buffer.size += ciphered

    // We allocated a tail segment, but didn't end up needing it. Recycle!
    if (s.pos == s.limit) {
      buffer.head = s.pop()
      SegmentPool.recycle(s)
    }
  }

  /** Returns the timeout of the wrapped [source]. */
  override fun timeout() = source.timeout()

  /** Marks this source as closed, so later [read] calls fail, and closes the wrapped [source]. */
  @Throws(IOException::class)
  override fun close() {
    closed = true
    source.close()
  }
}