
Upsample layer? #67

Open
tirinox opened this issue Aug 12, 2017 · 3 comments

Comments

@tirinox

tirinox commented Aug 12, 2017

Hello,
I need an upsample layer to implement U-Net for segmentation.
Is there one? Will it be implemented?

@bryant1410
Member

Hey. We're not planning to implement it in the short term.

But you can give it a try! Check out the code; maybe you'll find a way to do it.

@AndreJFBico

AndreJFBico commented Aug 23, 2017

I'm still testing it, but I'm using the following layer built on MPSCNNUpsamplingNearest.

import Foundation
import MetalPerformanceShaders

open class Upscale: NetworkLayer {
    
    /// Used to determine the filename for this layer's weights. (Ignored if there is no ParameterLoader)
    static var weightModifier: String = ""
    
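    /// Only strideX and strideY are used: they are the integer upscale factors.
    let size: ConvSize
    private var prevSize: LayerSize!
    
    /// Nearest-neighbour upsampling kernel (MPSCNNUpsamplingNearest requires iOS 11 or later).
    var upscale: MPSCNNUpsamplingNearest!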
    
    public init(size: ConvSize, id: String? = nil) {
        self.size = size
        super.init(id: id)
    }
    
    open override func initialize(network: Network, device: MTLDevice) {
        super.initialize(network: network, device: device)
        let incoming = getIncoming()
        prevSize = incoming[0].outputSize
        // Same number of feature channels; width and height multiplied by the scale factors.
        outputSize = LayerSize(f: prevSize.f,
                               w: size.strideX * prevSize.w,
                               h: size.strideY * prevSize.h)

        upscale = MPSCNNUpsamplingNearest(device: device,
                                          integerScaleFactorX: size.strideX,
                                          integerScaleFactorY: size.strideY)
        outputImage = MPSImage(device: device, imageDescriptor: MPSImageDescriptor(layerSize: outputSize))
    }
    
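    open override func updatedCheckpoint(device: MTLDevice) {
        // Nothing to reload: nearest-neighbour upsampling has no weights.
    }

    open override func execute(commandBuffer: MTLCommandBuffer) {
        let incoming = getIncoming()
        // Encode the upscaled copy of the single input into this layer's output image.
        upscale?.encode(commandBuffer: commandBuffer, sourceImage: incoming[0].outputImage, destinationImage: outputImage)
    }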
}
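While testing a layer like this, it can help to compare the GPU output against a small CPU reference. The sketch below is not part of Bender or Metal Performance Shaders; it is a hypothetical helper (upsampleNearest) that assumes a single feature channel stored row-major as [Float], and that nearest-neighbour upsampling with integer factors copies each input pixel into a scaleX by scaleY block of the output.

```swift
// Hypothetical CPU reference for integer-factor nearest-neighbour upsampling.
// Assumption (not from Bender/MPS): one feature channel, row-major [Float] data.
func upsampleNearest(_ input: [Float], width: Int, height: Int,
                     scaleX: Int, scaleY: Int) -> [Float] {
    precondition(input.count == width * height, "input size mismatch")
    precondition(scaleX > 0 && scaleY > 0, "scale factors must be positive")
    let outWidth = width * scaleX
    let outHeight = height * scaleY
    var output = [Float](repeating: 0, count: outWidth * outHeight)
    for y in 0..<outHeight {
        // Integer division maps each output row back to the input row it came from.
        let srcY = y / scaleY
        for x in 0..<outWidth {
            let srcX = x / scaleX
            output[y * outWidth + x] = input[srcY * width + srcX]
        }
    }
    return output
}

let input: [Float] = [1, 2,
                      3, 4]
let result = upsampleNearest(input, width: 2, height: 2, scaleX: 2, scaleY: 2)
print(result)
// prints [1.0, 1.0, 2.0, 2.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0, 3.0, 3.0, 4.0, 4.0]
```

The output size follows the same rule as the layer's outputSize: width times strideX and height times strideY, with the channel count unchanged.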

@bryant1410
Member

Hey @AndreJFBico, you can send a PR with that if you want to have it in Bender!
